Shield Agents
🛡️ Initial release - Shield Agents v1.0.0
de31cf7
Raw
History Blame Contribute Delete
14.5 kB
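"""
Command-line interface for Shield Agents.
Provides an interactive CLI that uses the Rich library for progress
spinners, tables, and colored output, falls back to plain-text output when
Rich is not installed, and offers a CI mode that prints JSON to stdout.
"""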
import argparse
import asyncio
import json
import logging
import sys
from pathlib import Path
from typing import Optional
from . import __version__
from .config import ShieldConfig
from .orchestrator import Orchestrator
from .shieldignore import ShieldIgnore
from .cache import ScanCache
# Named logger for CLI-level diagnostics.
logger = logging.getLogger(name="shield_agents.cli")
def setup_logging(verbose: bool = False, debug: bool = False):
    """Configure root logging for the CLI.

    ``debug`` takes precedence over ``verbose``: debug selects DEBUG, verbose
    selects INFO, and the default is WARNING. Delegates to
    ``logging.basicConfig``, which does nothing if the root logger already
    has handlers.

    Args:
        verbose: Emit INFO-level messages.
        debug: Emit DEBUG-level messages (wins over ``verbose``).
    """
    if debug:
        chosen_level = logging.DEBUG
    elif verbose:
        chosen_level = logging.INFO
    else:
        chosen_level = logging.WARNING

    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
        level=chosen_level,
    )
def create_parser() -> argparse.ArgumentParser:
    """Build the ``shield-agents`` argument parser.

    Subcommands:
        scan     Run a security scan against a target path.
        init     Write configuration templates into a directory.
        cache    Clear or inspect the scan cache.
        version  Print version information.

    Returns:
        The configured parser. ``args.command`` is ``None`` when no subcommand
        is given.
    """
    examples = """
Examples:
shield-agents scan ./my-project # Scan a project
shield-agents scan ./my-project --full # Full scan (ignore cache)
shield-agents scan ./my-project --fix # Scan with auto-fix suggestions
shield-agents scan ./my-project --sarif-only # Output only SARIF for GitHub
shield-agents init # Create .shieldignore template
shield-agents cache --clear # Clear scan cache
shield-agents cache --stats # Show cache statistics
shield-agents version # Show version info
"""
    root = argparse.ArgumentParser(
        prog="shield-agents",
        description="Shield Agents - AI-Powered Multi-Agent Cybersecurity Scanner",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=examples,
    )
    commands = root.add_subparsers(dest="command", help="Available commands")

    scan_cmd = commands.add_parser("scan", help="Run security scan")
    # (flags, keyword options) in registration order -- the order drives the
    # layout of ``shield-agents scan --help``.
    scan_spec = (
        (("target",), dict(help="Path to scan")),
        (("--config", "-c"), dict(help="Path to config YAML file")),
        (("--full",), dict(action="store_true", help="Full scan (ignore cache)")),
        (("--fix",), dict(action="store_true", help="Generate auto-fix suggestions")),
        (("--no-report",), dict(action="store_true", help="Skip report generation")),
        (("--sarif-only",), dict(action="store_true", help="Output only SARIF format")),
        (("--output", "-o"), dict(help="Output directory for reports")),
        (("--provider",), dict(choices=["mock", "openai", "anthropic", "ollama"], help="LLM provider")),
        (("--verbose", "-v"), dict(action="store_true", help="Verbose output")),
        (("--debug",), dict(action="store_true", help="Debug output")),
        (("--no-cache",), dict(action="store_true", help="Disable caching")),
        (("--no-dedup",), dict(action="store_true", help="Disable deduplication")),
        (("--no-ignore",), dict(action="store_true", help="Ignore .shieldignore rules")),
        (("--ci",), dict(action="store_true", help="CI/CD mode: SARIF output, JSON to stdout, silent except errors")),
        (("--format", "-f"), dict(choices=["rich", "json", "sarif", "plain"], default="rich", help="Output format (default: rich)")),
        (("--fail-threshold",), dict(type=int, default=75, help="Risk score threshold for CI failure (default: 75)")),
    )
    for flags, options in scan_spec:
        scan_cmd.add_argument(*flags, **options)

    init_cmd = commands.add_parser("init", help="Initialize Shield Agents configuration")
    init_cmd.add_argument("--path", default=".", help="Directory to initialize")

    cache_cmd = commands.add_parser("cache", help="Manage scan cache")
    for flag, text in (("--clear", "Clear the cache"), ("--stats", "Show cache statistics")):
        cache_cmd.add_argument(flag, action="store_true", help=text)

    commands.add_parser("version", help="Show version information")
    return root
async def run_scan_ci(args) -> int:
    """Run a scan in CI/CD mode and emit machine-readable results.

    The scan result (``result.to_dict()``) is printed to stdout as one JSON
    document so a pipeline can consume it directly. Logging stays at WARNING
    level. Reports are always generated: SARIF + JSON, or SARIF only when
    ``--sarif-only`` is set.

    Args:
        args: Parsed ``scan`` namespace from ``create_parser``.

    Returns:
        1 when ``result.risk_score`` is at or above ``--fail-threshold``
        (default 75), otherwise 0.
    """
    config = ShieldConfig(config_path=args.config) if args.config else ShieldConfig()
    config.target_path = args.target

    # Apply CLI overrides (kept in step with the interactive scan paths).
    if args.provider:
        config.llm.provider = args.provider
    if args.output:
        config.report.output_dir = args.output
    if args.no_cache:
        config.cache.enabled = False
    if args.no_dedup:
        config.deduplication.enabled = False
    if getattr(args, "no_ignore", False):
        # The other scan paths honour --no-ignore; CI mode previously dropped it.
        config.shieldignore.enabled = False

    # CI mode: always generate SARIF + JSON (or SARIF only if --sarif-only).
    config.report.formats = ["sarif"] if getattr(args, "sarif_only", False) else ["sarif", "json"]

    # Minimal logging so stdout carries only the JSON document.
    setup_logging(verbose=False, debug=False)

    orchestrator = Orchestrator(config)
    result = await orchestrator.scan(
        target_path=args.target,
        full_scan=args.full,
        generate_report=True,
        generate_fixes=args.fix,
    )

    print(json.dumps(result.to_dict(), indent=2))

    # An explicit threshold of 0 is meaningful ("fail on any risk"), so fall
    # back to the default only when the option is genuinely absent.
    threshold = getattr(args, "fail_threshold", None)
    if threshold is None:
        threshold = 75
    return 1 if result.risk_score >= threshold else 0
async def run_scan(args) -> int:
    """Execute the ``scan`` command.

    The output mode comes from ``--format``:

    * ``json``  -> delegated to ``run_scan_ci`` (JSON document on stdout).
    * ``plain`` -> delegated to ``_run_scan_simple`` (no Rich).
    * otherwise -> Rich panels and tables. Falls back to ``_run_scan_simple``
      when Rich cannot be imported.

    Finding text, file paths, fix snippets, report labels and stats are passed
    through ``rich.markup.escape`` before printing. Square brackets in them
    are then shown literally instead of being parsed as Rich markup tags.

    Args:
        args: Parsed ``scan`` namespace from ``create_parser``.

    Returns:
        1 when the risk score is at or above ``--fail-threshold`` (default 75),
        otherwise 0.
    """
    output_format = getattr(args, "format", "rich")
    if output_format == "json":
        return await run_scan_ci(args)
    if output_format == "plain":
        # "plain" is an accepted --format choice; render without Rich.
        return await _run_scan_simple(args)
    # NOTE(review): ``--format sarif`` still renders the Rich view; only
    # ``--sarif-only`` narrows the report formats. Confirm intended meaning.

    try:
        from rich import box
        from rich.console import Console
        from rich.markup import escape
        from rich.panel import Panel
        from rich.progress import Progress, SpinnerColumn, TextColumn
        from rich.table import Table
    except ImportError:
        # Fallback without Rich
        return await _run_scan_simple(args)

    console = Console()

    # Load configuration and apply CLI overrides.
    config = ShieldConfig(config_path=args.config) if args.config else ShieldConfig()
    config.target_path = args.target
    if args.provider:
        config.llm.provider = args.provider
    if args.output:
        config.report.output_dir = args.output
    if args.no_cache:
        config.cache.enabled = False
    if args.no_dedup:
        config.deduplication.enabled = False
    if args.no_ignore:
        config.shieldignore.enabled = False
    if args.sarif_only:
        config.report.formats = ["sarif"]
    if args.ci:
        # Checked after --sarif-only, so --ci wins when both are given.
        config.report.formats = ["sarif", "json"]
    config.verbose = args.verbose
    config.debug = args.debug
    setup_logging(args.verbose, args.debug)

    # Header
    console.print(Panel.fit(
        "[bold blue]Shield Agents[/bold blue] - AI-Powered Security Scanner\n"
        f"[dim]Version {__version__} | Provider: {config.llm.provider}[/dim]",
        border_style="blue",
    ))

    # Run the scan behind an indeterminate spinner.
    orchestrator = Orchestrator(config)
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    ) as progress:
        task = progress.add_task("Scanning...", total=None)
        result = await orchestrator.scan(
            target_path=args.target,
            full_scan=args.full,
            generate_report=not args.no_report,
            generate_fixes=args.fix,
        )
        progress.update(task, completed=True)

    severity_colors = {
        "CRITICAL": "bold red",
        "HIGH": "red",
        "MEDIUM": "yellow",
        "LOW": "cyan",
        "INFO": "dim",
    }
    findings = result.filtered_findings

    # Risk score panel (color bands are display-only).
    risk_color = "red" if result.risk_score >= 75 else "yellow" if result.risk_score >= 50 else "green"
    console.print(Panel.fit(
        f"[{risk_color}]Risk Score: {result.risk_score}/100[/{risk_color}]\n"
        f"Files scanned: {result.files_scanned}\n"
        f"Total findings: {len(findings)}\n"
        f"Scan duration: {result.scan_duration:.2f}s",
        title="Scan Summary",
    ))

    # Severity breakdown table
    sev_table = Table(title="Severity Breakdown", box=box.SIMPLE)
    sev_table.add_column("Severity", style="bold")
    sev_table.add_column("Count", justify="right")
    severity_counts = {}
    for finding in findings:
        sev = finding.get("severity", "MEDIUM").upper()
        severity_counts[sev] = severity_counts.get(sev, 0) + 1
    for sev in ["CRITICAL", "HIGH", "MEDIUM", "LOW", "INFO"]:
        count = severity_counts.get(sev, 0)
        if count > 0:
            style = severity_colors.get(sev, "white")
            sev_table.add_row(f"[{style}]{sev}[/{style}]", str(count))
    console.print(sev_table)

    # Findings table (first 50). Table cells are rendered as markup, so
    # scanner-provided text is escaped after truncation.
    if findings:
        findings_table = Table(title="Findings", box=box.SIMPLE, show_lines=True)
        findings_table.add_column("#", style="dim", width=4)
        findings_table.add_column("Severity", width=10)
        findings_table.add_column("Title", width=40)
        findings_table.add_column("File", width=30)
        findings_table.add_column("Line", width=6)
        findings_table.add_column("Source", width=15)
        for i, finding in enumerate(findings[:50], 1):
            sev = finding.get("severity", "MEDIUM").upper()
            sev_style = severity_colors.get(sev, "white")
            findings_table.add_row(
                str(i),
                f"[{sev_style}]{escape(sev)}[/{sev_style}]",
                escape(finding.get("title", "Unknown")[:40]),
                escape(finding.get("file", "N/A")[-30:]),
                escape(str(finding.get("line", "N/A"))),
                escape(finding.get("source", finding.get("agent", "unknown"))[:15]),
            )
        console.print(findings_table)

    # Auto-fix suggestions (first 20)
    if result.fixes:
        console.print(f"\n[bold green]Auto-fix Suggestions ({len(result.fixes)}):[/bold green]")
        for i, fix in enumerate(result.fixes[:20], 1):
            # Upper-case so fix severities map to the same colors as findings.
            fix_style = severity_colors.get(str(fix.get("severity", "INFO")).upper(), "dim")
            title = escape(str(fix.get("title", "Unknown")))
            console.print(f" {i}. [{fix_style}]{title}[/{fix_style}]")
            if fix.get("fix"):
                # Code snippets like ``data[i]`` would otherwise be parsed as tags.
                console.print(f" [dim]Fix: {escape(str(fix['fix'])[:100])}[/dim]")

    # Report locations. The label itself is escaped: an unescaped "[sarif]"
    # is parsed as a markup tag instead of being printed.
    if result.report_files:
        console.print("\n[bold]Reports generated:[/bold]")
        for fmt, path in result.report_files.items():
            label = escape(f"[{fmt}]")
            console.print(f" {label} {escape(str(path))}")

    # Stats (a repr may contain brackets, e.g. lists).
    if result.stats:
        console.print(f"\n[dim]Stats: {escape(str(result.stats))}[/dim]")

    # Honour --fail-threshold (previously hard-coded to 75 on this path).
    threshold = getattr(args, "fail_threshold", None)
    if threshold is None:
        threshold = 75
    return 1 if result.risk_score >= threshold else 0
async def _run_scan_simple(args) -> int:
    """Run a scan with plain ``print`` output (no Rich dependency).

    Used when Rich cannot be imported or ``--format plain`` is requested.
    With ``--format json`` only the JSON document is written to stdout (no
    banner), so the output stays machine-parseable.

    Args:
        args: Parsed ``scan`` namespace from ``create_parser``.

    Returns:
        1 when the risk score is at or above ``--fail-threshold`` (default 75),
        otherwise 0.
    """
    config = ShieldConfig(config_path=args.config) if args.config else ShieldConfig()
    config.target_path = args.target
    if args.provider:
        config.llm.provider = args.provider
    if args.output:
        config.report.output_dir = args.output
    if args.no_cache:
        config.cache.enabled = False
    if args.no_dedup:
        config.deduplication.enabled = False
    if args.no_ignore:
        config.shieldignore.enabled = False
    if args.sarif_only:
        config.report.formats = ["sarif"]
    if args.ci:
        # Checked after --sarif-only, so --ci wins when both are given.
        config.report.formats = ["sarif", "json"]
    setup_logging(args.verbose, args.debug)

    as_json = getattr(args, "format", "rich") == "json"
    if not as_json:
        # The banner would precede (and corrupt) the JSON document on stdout.
        print(f"Shield Agents v{__version__} - Scanning {args.target}...")

    orchestrator = Orchestrator(config)
    result = await orchestrator.scan(
        target_path=args.target,
        full_scan=args.full,
        generate_report=not args.no_report,
        generate_fixes=args.fix,
    )

    # A threshold of 0 is valid ("fail on any risk"); default only when absent.
    threshold = getattr(args, "fail_threshold", None)
    if threshold is None:
        threshold = 75
    exit_code = 1 if result.risk_score >= threshold else 0

    # JSON format: output to stdout and exit
    if as_json:
        print(json.dumps(result.to_dict(), indent=2))
        return exit_code

    print(f"\nRisk Score: {result.risk_score}/100")
    print(f"Files scanned: {result.files_scanned}")
    print(f"Findings: {len(result.filtered_findings)}")
    print(f"Duration: {result.scan_duration:.2f}s")
    if result.report_files:
        print("\nReports:")
        for fmt, path in result.report_files.items():
            print(f" [{fmt}] {path}")
    return exit_code
def run_init(args) -> int:
    """Initialize Shield Agents files in ``args.path``.

    Creates the directory if needed, writes a ``.shieldignore`` template via
    ``ShieldIgnore.create_template`` and, unless one already exists, a default
    ``config.yaml``.

    Args:
        args: Parsed ``init`` namespace (uses ``args.path``).

    Returns:
        0 on success.
    """
    path = args.path
    target_dir = Path(path)
    # Create the directory before either template is written (previously it
    # was only created right before writing config.yaml).
    target_dir.mkdir(parents=True, exist_ok=True)

    template_path = ShieldIgnore.create_template(path)
    print(f"Created .shieldignore template at: {template_path}")

    # Also create a default config (never overwrite an existing one).
    config_path = target_dir / "config.yaml"
    if not config_path.exists():
        # Nested keys must be indented: at column 0 every key would be a
        # top-level mapping and ``llm``/``scanner``/... would parse as null.
        config_content = """# Shield Agents Configuration
llm:
  provider: mock # mock, openai, anthropic, ollama
  # api_key: YOUR_API_KEY # Or set SHIELD_LLM_API_KEY env var
  model: gpt-4
  temperature: 0.1
scanner:
  sast_enabled: true
  secrets_enabled: true
cache:
  enabled: true
  incremental: true
deduplication:
  enabled: true
  merge_sources: true
report:
  output_dir: ./shield-reports
  formats:
    - html
    - sarif
    - json
"""
        config_path.write_text(config_content, encoding="utf-8")
        print(f"Created config template at: {config_path}")
    return 0
def run_cache(args) -> int:
    """Handle the ``cache`` subcommand.

    ``--clear`` empties the cache and saves it. Otherwise ``--stats`` prints
    the counters returned by ``ScanCache.get_stats()``. With neither flag a
    usage hint is printed. Always returns 0.
    """
    scan_cache = ScanCache()
    scan_cache._ensure_loaded()

    if not (args.clear or args.stats):
        print("Use --clear or --stats")
        return 0

    if args.clear:
        scan_cache.clear()
        scan_cache.save()
        print("Cache cleared.")
        return 0

    summary = scan_cache.get_stats()
    for label, key in (
        ("Cached files", "cached_files"),
        ("Total cached findings", "total_cached_findings"),
    ):
        print(f"{label}: {summary[key]}")
    print(f"Cache size: {summary['cache_size_bytes']} bytes")
    return 0
def main(argv: Optional[list] = None) -> int:
    """CLI entry point.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.

    Returns:
        Process exit code: the subcommand's return value, 130 when a scan is
        interrupted with Ctrl-C, or 0 when no subcommand is given.
    """
    parser = create_parser()
    args = parser.parse_args(argv)

    if args.command is None:
        parser.print_help()
        return 0
    if args.command == "version":
        print(f"Shield Agents v{__version__}")
        return 0
    if args.command == "init":
        return run_init(args)
    if args.command == "cache":
        return run_cache(args)
    if args.command == "scan":
        scan = run_scan_ci if args.ci else run_scan
        try:
            return asyncio.run(scan(args))
        except KeyboardInterrupt:
            # stderr keeps stdout clean for --ci / --format json consumers.
            print("\nScan interrupted.", file=sys.stderr)
            return 130
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())