| """ |
| Command-line interface for Shield Agents. |
| |
| Provides a beautiful, interactive CLI using the Rich library for |
| progress bars, tables, and colored output. |
| """ |
|
|
| import argparse |
| import asyncio |
| import json |
| import logging |
| import sys |
| from pathlib import Path |
| from typing import Optional |
|
|
| from . import __version__ |
| from .config import ShieldConfig |
| from .orchestrator import Orchestrator |
| from .shieldignore import ShieldIgnore |
| from .cache import ScanCache |
|
|
# Module-level logger for the CLI; the root logger's level and format are
# configured by setup_logging() through logging.basicConfig().
logger = logging.getLogger("shield_agents.cli")
|
|
|
|
def setup_logging(verbose: bool = False, debug: bool = False):
    """Configure root logging according to the requested verbosity.

    ``debug`` takes precedence over ``verbose``; with neither flag set only
    warnings and above are emitted.

    Args:
        verbose: Emit INFO-level records.
        debug: Emit DEBUG-level records.
    """
    if debug:
        chosen_level = logging.DEBUG
    elif verbose:
        chosen_level = logging.INFO
    else:
        chosen_level = logging.WARNING

    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
        level=chosen_level,
    )
|
|
|
|
def create_parser() -> argparse.ArgumentParser:
    """Build and return the ``shield-agents`` command-line parser.

    The parser exposes four sub-commands -- ``scan``, ``init``, ``cache`` and
    ``version`` -- and stores the chosen one in ``args.command`` (``None``
    when no sub-command is supplied).
    """
    usage_examples = """
Examples:
shield-agents scan ./my-project # Scan a project
shield-agents scan ./my-project --full # Full scan (ignore cache)
shield-agents scan ./my-project --fix # Scan with auto-fix suggestions
shield-agents scan ./my-project --sarif-only # Output only SARIF for GitHub
shield-agents init # Create .shieldignore template
shield-agents cache --clear # Clear scan cache
shield-agents cache --stats # Show cache statistics
shield-agents version # Show version info
"""
    root = argparse.ArgumentParser(
        prog="shield-agents",
        description="Shield Agents - AI-Powered Multi-Agent Cybersecurity Scanner",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    commands = root.add_subparsers(dest="command", help="Available commands")

    # --- scan -------------------------------------------------------------
    # Declared as data so the option list reads as a table; the order is the
    # order in which options appear in ``--help``.
    scan_options = (
        (("target",), {"help": "Path to scan"}),
        (("--config", "-c"), {"help": "Path to config YAML file"}),
        (("--full",), {"action": "store_true", "help": "Full scan (ignore cache)"}),
        (("--fix",), {"action": "store_true", "help": "Generate auto-fix suggestions"}),
        (("--no-report",), {"action": "store_true", "help": "Skip report generation"}),
        (("--sarif-only",), {"action": "store_true", "help": "Output only SARIF format"}),
        (("--output", "-o"), {"help": "Output directory for reports"}),
        (
            ("--provider",),
            {"choices": ["mock", "openai", "anthropic", "ollama"], "help": "LLM provider"},
        ),
        (("--verbose", "-v"), {"action": "store_true", "help": "Verbose output"}),
        (("--debug",), {"action": "store_true", "help": "Debug output"}),
        (("--no-cache",), {"action": "store_true", "help": "Disable caching"}),
        (("--no-dedup",), {"action": "store_true", "help": "Disable deduplication"}),
        (("--no-ignore",), {"action": "store_true", "help": "Ignore .shieldignore rules"}),
        (
            ("--ci",),
            {
                "action": "store_true",
                "help": "CI/CD mode: SARIF output, JSON to stdout, silent except errors",
            },
        ),
        (
            ("--format", "-f"),
            {
                "choices": ["rich", "json", "sarif", "plain"],
                "default": "rich",
                "help": "Output format (default: rich)",
            },
        ),
        (
            ("--fail-threshold",),
            {
                "type": int,
                "default": 75,
                "help": "Risk score threshold for CI failure (default: 75)",
            },
        ),
    )
    scan = commands.add_parser("scan", help="Run security scan")
    for flags, options in scan_options:
        scan.add_argument(*flags, **options)

    # --- init -------------------------------------------------------------
    init = commands.add_parser("init", help="Initialize Shield Agents configuration")
    init.add_argument("--path", default=".", help="Directory to initialize")

    # --- cache ------------------------------------------------------------
    cache = commands.add_parser("cache", help="Manage scan cache")
    cache.add_argument("--clear", action="store_true", help="Clear the cache")
    cache.add_argument("--stats", action="store_true", help="Show cache statistics")

    # --- version ----------------------------------------------------------
    commands.add_parser("version", help="Show version information")

    return root
|
|
|
|
async def run_scan_ci(args) -> int:
    """Run a scan in CI/CD mode and print the result as JSON on stdout.

    Logging is pinned to WARNING and reports are always generated, as SARIF
    plus JSON, or SARIF alone when ``--sarif-only`` is set.

    Args:
        args: Parsed ``scan`` namespace (see ``create_parser``).

    Returns:
        ``1`` when the risk score is greater than or equal to
        ``args.fail_threshold`` (75 when the attribute is missing or
        ``None``), otherwise ``0``.
    """
    config = ShieldConfig(config_path=args.config) if args.config else ShieldConfig()
    config.target_path = args.target

    # Command-line overrides take precedence over the loaded configuration.
    if args.provider:
        config.llm.provider = args.provider
    if args.output:
        config.report.output_dir = args.output
    if args.no_cache:
        config.cache.enabled = False
    if args.no_dedup:
        config.deduplication.enabled = False
    # --no-ignore used to be dropped silently in CI mode although the other
    # scan paths honour it.
    if getattr(args, "no_ignore", False):
        config.shieldignore.enabled = False

    if getattr(args, "sarif_only", False):
        config.report.formats = ["sarif"]
    else:
        config.report.formats = ["sarif", "json"]

    # Keep log output to warnings and above so it does not drown the JSON.
    setup_logging(verbose=False, debug=False)

    orchestrator = Orchestrator(config)
    result = await orchestrator.scan(
        target_path=args.target,
        full_scan=args.full,
        generate_report=True,
        generate_fixes=args.fix,
    )

    print(json.dumps(result.to_dict(), indent=2))

    # Compare against None rather than truthiness: ``--fail-threshold 0`` is a
    # legitimate value and must not be replaced by the default.
    threshold = getattr(args, "fail_threshold", None)
    if threshold is None:
        threshold = 75
    return 1 if result.risk_score >= threshold else 0
|
|
|
|
async def run_scan(args) -> int:
    """Execute the ``scan`` command and render the result.

    Dispatch depends on ``args.format``:

    * ``json``  -- delegated to ``run_scan_ci`` (JSON on stdout).
    * ``plain`` -- delegated to ``_run_scan_simple`` (plain ``print`` output).
    * ``sarif`` -- Rich summary, with report generation limited to SARIF
      (same effect as ``--sarif-only``).
    * ``rich``  -- Rich summary; also used when ``args`` has no ``format``.

    When the Rich library cannot be imported, ``_run_scan_simple`` is used.

    Args:
        args: Parsed ``scan`` namespace (see ``create_parser``).

    Returns:
        ``1`` when the risk score is greater than or equal to
        ``args.fail_threshold`` (75 when missing or ``None``), otherwise ``0``.
    """
    output_format = getattr(args, "format", "rich")

    if output_format == "json":
        return await run_scan_ci(args)
    if output_format == "plain":
        # ``plain`` is an accepted --format choice; route it to the renderer
        # that emits no Rich markup instead of silently ignoring it.
        return await _run_scan_simple(args)

    try:
        from rich import box
        from rich.console import Console
        from rich.markup import escape
        from rich.panel import Panel
        from rich.progress import Progress, SpinnerColumn, TextColumn
        from rich.table import Table
    except ImportError:
        # Rich is optional: fall back to plain output.
        return await _run_scan_simple(args)

    console = Console()

    config = ShieldConfig(config_path=args.config) if args.config else ShieldConfig()
    config.target_path = args.target

    # Command-line overrides take precedence over the loaded configuration.
    if args.provider:
        config.llm.provider = args.provider
    if args.output:
        config.report.output_dir = args.output
    if args.no_cache:
        config.cache.enabled = False
    if args.no_dedup:
        config.deduplication.enabled = False
    if args.no_ignore:
        config.shieldignore.enabled = False
    if args.sarif_only or output_format == "sarif":
        config.report.formats = ["sarif"]
    if args.ci:
        # Checked last so that --ci wins over --sarif-only, as before.
        config.report.formats = ["sarif", "json"]

    config.verbose = args.verbose
    config.debug = args.debug
    setup_logging(args.verbose, args.debug)

    console.print(Panel.fit(
        "[bold blue]Shield Agents[/bold blue] - AI-Powered Security Scanner\n"
        f"[dim]Version {__version__} | Provider: {config.llm.provider}[/dim]",
        border_style="blue",
    ))

    orchestrator = Orchestrator(config)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    ) as progress:
        # total=None gives an indeterminate spinner.
        task = progress.add_task("Scanning...", total=None)
        result = await orchestrator.scan(
            target_path=args.target,
            full_scan=args.full,
            generate_report=not args.no_report,
            generate_fixes=args.fix,
        )
        progress.update(task, completed=True)

    severity_colors = {
        "CRITICAL": "bold red",
        "HIGH": "red",
        "MEDIUM": "yellow",
        "LOW": "cyan",
        "INFO": "dim",
    }

    # --- summary panel ----------------------------------------------------
    if result.risk_score >= 75:
        risk_color = "red"
    elif result.risk_score >= 50:
        risk_color = "yellow"
    else:
        risk_color = "green"
    console.print(Panel.fit(
        f"[{risk_color}]Risk Score: {result.risk_score}/100[/{risk_color}]\n"
        f"Files scanned: {result.files_scanned}\n"
        f"Total findings: {len(result.filtered_findings)}\n"
        f"Scan duration: {result.scan_duration:.2f}s",
        title="Scan Summary",
    ))

    # --- severity breakdown -----------------------------------------------
    sev_table = Table(title="Severity Breakdown", box=box.SIMPLE)
    sev_table.add_column("Severity", style="bold")
    sev_table.add_column("Count", justify="right")

    severity_counts = {}
    for finding in result.filtered_findings:
        sev = finding.get("severity", "MEDIUM").upper()
        severity_counts[sev] = severity_counts.get(sev, 0) + 1

    # Fixed order, most severe first; severities outside this list are counted
    # but not shown.
    for sev in ["CRITICAL", "HIGH", "MEDIUM", "LOW", "INFO"]:
        count = severity_counts.get(sev, 0)
        if count > 0:
            style = severity_colors.get(sev, "white")
            sev_table.add_row(f"[{style}]{sev}[/{style}]", str(count))

    console.print(sev_table)

    # --- findings table (first 50 only) -----------------------------------
    if result.filtered_findings:
        findings_table = Table(title="Findings", box=box.SIMPLE, show_lines=True)
        findings_table.add_column("#", style="dim", width=4)
        findings_table.add_column("Severity", width=10)
        findings_table.add_column("Title", width=40)
        findings_table.add_column("File", width=30)
        findings_table.add_column("Line", width=6)
        findings_table.add_column("Source", width=15)

        for i, finding in enumerate(result.filtered_findings[:50], 1):
            sev = finding.get("severity", "MEDIUM").upper()
            sev_style = severity_colors.get(sev, "white")
            # Finding text comes from scan results and may contain "[...]";
            # escape it so Rich does not interpret it as markup.
            findings_table.add_row(
                str(i),
                f"[{sev_style}]{escape(sev)}[/{sev_style}]",
                escape(finding.get("title", "Unknown")[:40]),
                # Keep the tail of the path: the file name matters most.
                escape(finding.get("file", "N/A")[-30:]),
                escape(str(finding.get("line", "N/A"))),
                escape(finding.get("source", finding.get("agent", "unknown"))[:15]),
            )

        console.print(findings_table)

    # --- auto-fix suggestions (first 20 only) -----------------------------
    if result.fixes:
        console.print(f"\n[bold green]Auto-fix Suggestions ({len(result.fixes)}):[/bold green]")
        for i, fix in enumerate(result.fixes[:20], 1):
            # Upper-case before the lookup, matching how findings are handled.
            fix_style = severity_colors.get(str(fix.get("severity", "INFO")).upper(), "dim")
            fix_title = escape(str(fix.get("title", "Unknown")))
            console.print(f" {i}. [{fix_style}]{fix_title}[/{fix_style}]")
            if fix.get("fix"):
                console.print(f" [dim]Fix: {escape(fix['fix'][:100])}[/dim]")

    # --- generated reports ------------------------------------------------
    if result.report_files:
        console.print("\n[bold]Reports generated:[/bold]")
        for fmt, path in result.report_files.items():
            # A bare "[sarif]" would be parsed as a markup tag and vanish from
            # the output, so the label has to be escaped.
            label = escape(f"[{fmt}]")
            console.print(f" {label} {escape(str(path))}")

    if result.stats:
        console.print(f"\n[dim]Stats: {escape(str(result.stats))}[/dim]")

    # Honour --fail-threshold instead of a hard-coded 75. Compare against None
    # so that an explicit 0 is respected.
    threshold = getattr(args, "fail_threshold", None)
    if threshold is None:
        threshold = 75
    return 1 if result.risk_score >= threshold else 0
|
|
|
|
async def _run_scan_simple(args) -> int:
    """Run a scan and report with plain ``print`` output (no Rich).

    With ``args.format == "json"`` only the JSON document is written to
    stdout; otherwise a banner, a short summary and the list of generated
    report files are printed.

    Args:
        args: Parsed ``scan`` namespace (see ``create_parser``).

    Returns:
        ``1`` when the risk score is greater than or equal to
        ``args.fail_threshold`` (75 when missing or ``None``), otherwise ``0``.
    """
    output_format = getattr(args, "format", "rich")
    json_output = output_format == "json"

    config = ShieldConfig(config_path=args.config) if args.config else ShieldConfig()
    config.target_path = args.target

    # Command-line overrides take precedence over the loaded configuration.
    if args.provider:
        config.llm.provider = args.provider
    if args.output:
        config.report.output_dir = args.output
    if args.no_cache:
        config.cache.enabled = False
    if args.no_dedup:
        config.deduplication.enabled = False
    if args.no_ignore:
        config.shieldignore.enabled = False
    if args.sarif_only or output_format == "sarif":
        config.report.formats = ["sarif"]
    if args.ci:
        # Checked last so that --ci wins over --sarif-only, as before.
        config.report.formats = ["sarif", "json"]

    # Mirror run_scan so the orchestrator sees the same verbosity settings
    # whichever renderer ends up being used.
    config.verbose = args.verbose
    config.debug = args.debug
    setup_logging(args.verbose, args.debug)

    if not json_output:
        # The banner must not precede the JSON payload, or stdout is no
        # longer parseable as a single JSON document.
        print(f"Shield Agents v{__version__} - Scanning {args.target}...")

    orchestrator = Orchestrator(config)
    result = await orchestrator.scan(
        target_path=args.target,
        full_scan=args.full,
        generate_report=not args.no_report,
        generate_fixes=args.fix,
    )

    # Compare against None rather than truthiness so ``--fail-threshold 0``
    # is respected; the same threshold now drives every exit path.
    threshold = getattr(args, "fail_threshold", None)
    if threshold is None:
        threshold = 75
    exit_code = 1 if result.risk_score >= threshold else 0

    if json_output:
        print(json.dumps(result.to_dict(), indent=2))
        return exit_code

    print(f"\nRisk Score: {result.risk_score}/100")
    print(f"Files scanned: {result.files_scanned}")
    print(f"Findings: {len(result.filtered_findings)}")
    print(f"Duration: {result.scan_duration:.2f}s")

    if result.report_files:
        print("\nReports:")
        for fmt, path in result.report_files.items():
            print(f" [{fmt}] {path}")

    return exit_code
|
|
|
|
def run_init(args) -> int:
    """Initialise a directory with Shield Agents starter files.

    Ensures ``args.path`` exists, asks ``ShieldIgnore.create_template`` for a
    ``.shieldignore`` template there, and writes a starter ``config.yaml``
    unless one is already present (an existing file is never overwritten).

    Args:
        args: Parsed ``init`` namespace; only ``args.path`` is read.

    Returns:
        Always ``0``.
    """
    target_dir = Path(args.path)
    # Create the directory before anything is written into it; previously it
    # was created only after the .shieldignore template had been requested,
    # and only when config.yaml was missing.
    target_dir.mkdir(parents=True, exist_ok=True)

    template_path = ShieldIgnore.create_template(args.path)
    print(f"Created .shieldignore template at: {template_path}")

    config_path = target_dir / "config.yaml"
    if not config_path.exists():
        # Nested keys are indented under their section so the file parses
        # into the intended sections (llm, scanner, cache, ...).
        config_content = """# Shield Agents Configuration
llm:
  provider: mock # mock, openai, anthropic, ollama
  # api_key: YOUR_API_KEY # Or set SHIELD_LLM_API_KEY env var
  model: gpt-4
  temperature: 0.1

scanner:
  sast_enabled: true
  secrets_enabled: true

cache:
  enabled: true
  incremental: true

deduplication:
  enabled: true
  merge_sources: true

report:
  output_dir: ./shield-reports
  formats:
    - html
    - sarif
    - json
"""
        # Explicit encoding keeps the output independent of the locale.
        config_path.write_text(config_content, encoding="utf-8")
        print(f"Created config template at: {config_path}")

    return 0
|
|
|
|
def run_cache(args) -> int:
    """Handle the ``cache`` sub-command.

    ``--clear`` takes precedence over ``--stats``; with neither flag a short
    usage hint is printed.

    Args:
        args: Parsed ``cache`` namespace (``clear`` and ``stats`` flags).

    Returns:
        Always ``0``.
    """
    store = ScanCache()
    store._ensure_loaded()

    if args.clear:
        store.clear()
        store.save()
        print("Cache cleared.")
        return 0

    if not args.stats:
        print("Use --clear or --stats")
        return 0

    summary = store.get_stats()
    print(f"Cached files: {summary['cached_files']}")
    print(f"Total cached findings: {summary['total_cached_findings']}")
    print(f"Cache size: {summary['cache_size_bytes']} bytes")
    return 0
|
|
|
|
def main(argv: Optional[list] = None) -> int:
    """Parse command-line arguments and dispatch to the chosen sub-command.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            ``main()`` callers are unaffected.

    Returns:
        The process exit status: the sub-command's return value, ``0`` when
        no sub-command is given (help is printed), or ``130`` when a scan is
        interrupted with Ctrl-C.
    """
    parser = create_parser()
    args = parser.parse_args(argv)

    if args.command is None:
        parser.print_help()
        return 0

    if args.command == "version":
        print(f"Shield Agents v{__version__}")
        return 0

    if args.command == "init":
        return run_init(args)

    if args.command == "cache":
        return run_cache(args)

    if args.command == "scan":
        try:
            if args.ci:
                return asyncio.run(run_scan_ci(args))
            return asyncio.run(run_scan(args))
        except KeyboardInterrupt:
            # stderr, so the notice cannot end up inside the JSON that CI
            # mode writes to stdout. 130 = 128 + SIGINT.
            print("\nScan interrupted.", file=sys.stderr)
            return 130

    return 0
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|
|