Carlos
chore: ruff format after mypy fixes
c05edc5
Raw
History Blame Contribute Delete
16.5 kB
import asyncio
from pathlib import Path
import typer
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from scanner.certification import CertificationPipeline
from scanner.config import Settings
from scanner.monitor import Alerter, MonitorScheduler, MonitorStore
from scanner.pipeline import PipelineOrchestrator
from scanner.policies import PolicyGenerator
from scanner.redteam import AdversarialPageGenerator, ScannerEvaluator
from scanner.reporters import JSONReporter, MarkdownReporter, SimpleReporter
from scanner.reputation import ReputationEngine
from scanner.validator import AgentValidator, BehaviorEvaluator
# Root Typer application; each ``@app.command()`` in this module registers a
# subcommand on it. With no arguments the CLI prints its help text.
app = typer.Typer(
    name="pis",
    help="Prompt Injection Scanner — Analyze content for AI agent threats",
    no_args_is_help=True,
)
# Shared Rich console used by every command for styled terminal output.
console = Console()
@app.command()
def scan(
    url: str = typer.Argument(None, help="URL to scan"),
    file: str = typer.Option(None, "--file", "-f", help="Local file to scan"),
    paste: str = typer.Option(None, "--paste", "-p", help="Raw text to scan"),
    format: str = typer.Option("rich", "--format", help="Output format: rich, json, markdown, simple"),
    output: str = typer.Option(None, "--output", "-o", help="Save output to file"),
    ci: bool = typer.Option(False, "--ci", help="CI mode: exit code reflects risk"),
    threshold: str = typer.Option("high", "--threshold", help="CI failure threshold: low, medium, high, critical"),
    llm: bool = typer.Option(False, "--llm", help="Enable LLM classification (requires API key)"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
    reputation: bool = typer.Option(False, "--reputation", "-r", help="Check and store reputation"),
):
    """Scan a URL, file, or paste for prompt injection."""
    # The docstring above doubles as the command's ``--help`` text, so the
    # longer description lives in comments.
    #
    # Flow: resolve the effective format, validate the arguments, run the
    # scan, render or save the report, optionally print the URL's reputation,
    # then apply the CI gate. Exits with status 1 when no input is given, the
    # format is unknown, or (with --ci) the risk category reaches the
    # threshold.
    #
    # NOTE(review): ``llm`` is accepted but never read in this function —
    # confirm whether it should be forwarded to ``Settings``.

    # A recognised --output extension overrides --format.
    # NOTE(review): ".html" maps to "html", for which no reporter is
    # registered below, so ``-o report.html`` is rejected as unknown.
    if output:
        ext = Path(output).suffix.lower()
        format_map = {".json": "json", ".md": "markdown", ".html": "html"}
        format = format_map.get(ext, format)

    # Reporter classes keyed by format name; only the selected one is built.
    reporters = {
        "json": JSONReporter,
        "markdown": MarkdownReporter,
        "simple": SimpleReporter,
    }

    # Validate everything up front so a bad invocation fails before the
    # (potentially slow) scan runs instead of after it.
    if not (url or file or paste):
        console.print("[red]Error:[/] provide a URL, --file, or --paste")
        raise typer.Exit(1)
    if format != "rich" and format not in reporters:
        console.print(f"[red]Unknown format:[/] {format}")
        raise typer.Exit(1)

    settings = Settings()
    if verbose:
        settings.debug = True
    orchestrator = PipelineOrchestrator(settings=settings)
    rep = ReputationEngine() if reputation else None

    async def run():
        # Precedence when several inputs are given: URL, then file, then paste.
        if url:
            report = await orchestrator.scan_url(url)
            if rep is not None:
                rep.record_scan(url, report)
            return report
        if file:
            return await orchestrator.scan_file(file)
        return await orchestrator.scan_paste(paste)

    report = asyncio.run(run())

    if format == "rich":
        _display_rich(report, verbose)
    else:
        text = reporters[format]().render(report)
        if output:
            # Explicit encoding: the locale default may be unable to encode
            # characters that appear in a report.
            Path(output).write_text(text, encoding="utf-8")
            console.print(f"[green]Output saved to[/] {output}")
        else:
            print(text)

    # Reputation is only tracked for URL scans.
    if rep is not None and url:
        info = rep.query(url)
        console.print(f"\n[bold]Reputation:[/] {info['trust_level']} (score: {info['score']:.0f}/100)")

    if ci:
        severity_order = {"low": 1, "medium": 2, "high": 3, "critical": 4}
        # Unrecognised thresholds fall back to "high" (3); matching is
        # case-insensitive so "--threshold HIGH" behaves like "high".
        threshold_val = severity_order.get(threshold.lower(), 3)
        # Categories outside the table (e.g. "none") rank below every threshold.
        cat_val = severity_order.get(report.risk_category, 0)
        if cat_val >= threshold_val:
            console.print(f"[red]CI FAILED:[/] Risk {report.risk_category} >= threshold {threshold}")
            raise typer.Exit(1)
        console.print(f"[green]CI PASSED:[/] Risk {report.risk_category} < threshold {threshold}")
@app.command()
def policies(
    url: str = typer.Argument(..., help="URL to scan and generate policies for"),
    output: str = typer.Option("pis-policies.yaml", "--output", "-o", help="Output file"),
):
    """Scan a URL and generate MCPGuard-compatible policy rules."""
    # The docstring above is the command's ``--help`` text; details live here.
    # Scans ``url``, converts the report to YAML via ``PolicyGenerator`` and
    # writes the result to ``output``.
    settings = Settings()
    orchestrator = PipelineOrchestrator(settings=settings)
    gen = PolicyGenerator()

    async def run():
        report = await orchestrator.scan_url(url)
        policy_yaml = gen.to_mcpguard_yaml(report)
        # Write UTF-8 explicitly instead of relying on the locale encoding.
        Path(output).write_text(policy_yaml, encoding="utf-8")
        console.print(f"[green]Policies saved to[/] {output}")

    asyncio.run(run())
# ─── Monitor ──────────────────────────────────────────────────────────────
@app.command()
def monitor(
    url: str = typer.Argument(None, help="URL to start monitoring"),
    list_urls: bool = typer.Option(False, "--list", "-l", help="List monitored URLs"),
    interval: float = typer.Option(6.0, "--interval", "-i", help="Scan interval in hours"),
    webhook: str = typer.Option("", "--webhook", "-w", help="Alert webhook URL"),
    daemon: bool = typer.Option(False, "--daemon", "-d", help="Run as daemon (continuous monitoring)"),
):
    """Monitor URLs for changes in risk posture over time."""
    # The docstring above is the command's ``--help`` text; details live here.
    # ``--list`` prints the stored URLs as a table and returns. Otherwise a
    # URL is required: the scheduler is started and, with ``--daemon``, kept
    # alive until Ctrl+C. Exits with status 1 when neither is provided.
    store = MonitorStore()
    orchestrator = PipelineOrchestrator()
    alerter = Alerter(store)

    if list_urls:
        entries = store.get_urls()
        if not entries:
            console.print("[yellow]No monitored URLs[/]")
            return
        table = Table(title="Monitored URLs")
        for heading in ("URL", "Score", "Category", "Scans", "Last Scan"):
            table.add_column(heading)
        for e in entries:
            table.add_row(
                e["url"],
                str(e["last_risk_score"]),
                e["last_risk_category"],
                str(e["total_scans"]),
                # ``or`` also covers a key that is present but None/empty,
                # which a plain ``.get(key, default)`` would let through.
                e.get("last_scan_at") or "never",
            )
        console.print(table)
        return

    if not url:
        console.print("[red]Provide a URL or --list[/]")
        # Signal the usage error to the shell, as ``scan`` does.
        raise typer.Exit(1)

    # NOTE(review): ``url``, ``interval`` and ``webhook`` are only echoed
    # here; nothing in this function hands them to the store or scheduler.
    # Confirm where the URL is supposed to be registered.
    scheduler = MonitorScheduler(store, orchestrator, alerter)
    scheduler.start()
    console.print(f"[green]Monitoring[/] {url} every {interval}h")

    if not daemon:
        scheduler.stop()
        return

    console.print("[bold]Running...[/] Press Ctrl+C to stop")
    try:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # Newer Python versions raise instead of implicitly creating a
            # loop when none is current; create and install one ourselves.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        loop.run_forever()
    except KeyboardInterrupt:
        scheduler.stop()
# ─── Proxy ────────────────────────────────────────────────────────────────
@app.command()
def proxy(
    port: int = typer.Option(9090, "--port", "-p", help="Proxy port"),
    mode: str = typer.Option("strip", "--mode", "-m", help="strip/rewrite/block/passthrough"),
):
    """Run the Content Safety Proxy server."""
    # The proxy server module is imported at call time, not at module import.
    from scanner.proxy.server import run_proxy

    banner = f"[green]Starting Content Safety Proxy on :{port} (mode: {mode})[/]"
    console.print(banner)
    # Hands control to the proxy server with the chosen port and mode.
    run_proxy(port=port, mode=mode)
# ─── Validate ─────────────────────────────────────────────────────────────
@app.command()
def validate(
    url: str = typer.Argument(..., help="URL to test an agent against"),
    provider: str = typer.Option("browser_use", "--provider", help="Agent provider: browser_use, playwright"),
    scan_first: bool = typer.Option(True, "--scan/--no-scan", help="Run scanner first"),
):
    """Validate a URL by running a real AI agent against it."""
    # The docstring above is the command's ``--help`` text; details live here.
    # Optionally scans ``url`` first, runs the agent against it, then prints
    # how the agent behaved relative to the scanner's findings.
    from rich.markup import escape

    orchestrator = PipelineOrchestrator()
    validator = AgentValidator(provider=provider)
    evaluator = BehaviorEvaluator()

    async def run():
        # Without a prior scan the evaluator gets no findings to compare with.
        findings = []
        if scan_first:
            report = await orchestrator.scan_url(url)
            console.print(f"[bold]Scanner found[/] {report.total_findings} issues (risk: {report.risk_score}/100)")
            findings = report.findings
        console.print(f"[bold]Running agent[/] ({provider}) against {url}...")
        session = await validator.validate(url)
        result = evaluator.evaluate(session, findings)
        console.print("\n[bold]Agent Vulnerability Report[/]")
        console.print(f" Steps: {result.total_steps}")
        console.print(f" Mission success: {result.mission_success}")
        console.print(f" Injections triggered: {len(result.injections_triggered)}")
        console.print(f" Injections ignored: {len(result.injections_ignored)}")
        console.print(f" Vulnerability score: {result.overall_vulnerability_score}/100")
        if result.injections_triggered:
            console.print("\n[red]Triggered injections:[/]")
            for t in result.injections_triggered:
                # Rich would parse "[high]" as a style tag and drop it, and the
                # injection text comes from the scanned page, so both are
                # escaped to print literally.
                severity_tag = escape(f"[{t.severity}]")
                console.print(f" • {severity_tag} {escape(t.injection_text[:80])}")
        return result

    asyncio.run(run())
# ─── Reputation ───────────────────────────────────────────────────────────
@app.command()
def reputation(
    url: str = typer.Argument(None, help="URL to query reputation for"),
    list_threats: bool = typer.Option(False, "--threats", "-t", help="List recent threats"),
):
    """Query or display reputation information."""
    # The docstring above is the command's ``--help`` text; details live here.
    # ``--threats`` prints a table of recent threats; otherwise the stored
    # reputation of ``url`` is shown. Exits with status 1 when neither is given.
    rep = ReputationEngine()

    if list_threats:
        threats = rep.recent_threats()
        if not threats:
            console.print("[green]No recent threats[/]")
            return
        table = Table(title="Recent Threats (24h)")
        for heading in ("Domain", "Score", "Level", "Critical"):
            table.add_column(heading)
        for t in threats:
            table.add_row(t["domain"], f"{t['score']:.0f}", t["trust_level"], str(t["critical_findings"]))
        console.print(table)
        return

    if not url:
        console.print("[red]Provide a URL or --threats[/]")
        # Signal the usage error to the shell, as ``scan`` does.
        raise typer.Exit(1)

    info = rep.query(url)
    console.print(f"[bold]Reputation for[/] {url}")
    console.print(f" Trust level: [bold]{info['trust_level']}[/]")
    console.print(f" Score: {info['score']:.0f}/100")
    console.print(f" Total scans: {info['total_scans']}")
    console.print(f" Total findings: {info['total_findings']}")
# ─── Red Team ─────────────────────────────────────────────────────────────
@app.command()
def redteam(
    count: int = typer.Option(5, "--count", "-c", help="Number of adversarial pages to generate"),
    template: str = typer.Option("ecommerce", "--template", "-t", help="Page template: ecommerce, blog"),
):
    """Generate adversarial pages and evaluate the scanner."""
    # The docstring above is the command's ``--help`` text; details live here.
    # Generates ``count`` pages from ``template``, runs the scanner evaluator
    # over them and prints precision/recall/F1 plus a per-category breakdown.
    orchestrator = PipelineOrchestrator()
    generator = AdversarialPageGenerator()
    evaluator = ScannerEvaluator(orchestrator)

    async def run():
        console.print(f"[bold]Generating[/] {count} adversarial pages...")
        pages = [generator.generate(template=template) for _ in range(count)]
        console.print("[bold]Evaluating scanner...[/]")
        result = await evaluator.evaluate(pages)
        console.print("\n[bold]Scanner Evaluation Report[/]")
        console.print(f" Pages: {result.total_pages}")
        console.print(f" Injections: {result.total_injections}")
        console.print(f" Precision: {result.precision:.1%}")
        console.print(f" Recall: {result.recall:.1%}")
        console.print(f" F1 Score: {result.f1:.1%}")
        if result.by_category:
            console.print("\n[bold]By Category:[/]")
            for cat, data in result.by_category.items():
                # Per-category rate = tp / (tp + fn); 0 when the category has
                # no samples, which also avoids dividing by zero.
                detected = data["tp"]
                total = detected + data["fn"]
                rate = detected / total if total > 0 else 0
                console.print(f" {cat}: {rate:.0%} ({detected}/{total})")
        if result.recommendations:
            console.print("\n[yellow]Recommendations:[/]")
            for r in result.recommendations:
                console.print(f" • {r}")

    asyncio.run(run())
# ─── Certification ────────────────────────────────────────────────────────
@app.command()
def certify(
    url: str = typer.Argument(None, help="URL to certify"),
    email: str = typer.Option("", "--email", "-e", help="Owner email"),
    org: str = typer.Option("", "--org", "-o", help="Organization name"),
    verify: str = typer.Option(None, "--verify", "-v", help="Verify a certificate ID"),
    badge: str = typer.Option(None, "--badge", help="Generate badge HTML for certificate ID"),
):
    """Apply for, verify, or generate badge for AgentSafe certification."""
    # The docstring above is the command's ``--help`` text; details live here.
    # Modes, in order of precedence: ``--verify`` prints certificate details,
    # ``--badge`` prints badge HTML, otherwise an application is submitted
    # for ``url``. The URL is only needed for the application, so it is
    # optional on the command line and checked just before applying.
    # Exits with status 1 when the URL is missing for an application or the
    # application result carries an "error" entry.
    orchestrator = PipelineOrchestrator()
    certification = CertificationPipeline(orchestrator)

    async def run():
        if verify:
            info = certification.verify(verify)
            console.print(f"[bold]Certificate {verify}[/]")
            console.print(f" Valid: {info['valid']}")
            console.print(f" Status: {info['status']}")
            console.print(f" URL: {info['url']}")
            console.print(f" Issued: {info['issued_at']}")
            console.print(f" Expires: {info['expires_at']}")
            return

        if badge:
            html = certification.badge_html(badge)
            if html:
                # Plain print keeps the HTML verbatim for copy/paste; the Rich
                # console would wrap and highlight it.
                print(html)
            else:
                console.print("[red]Certificate not found or expired[/]")
            return

        if not url:
            console.print("[red]Provide a URL to certify, or use --verify / --badge[/]")
            raise typer.Exit(1)

        result = await certification.apply(url, email, org)
        console.print("[bold]Certification Applied[/]")
        console.print(f" Certificate ID: {result.get('certificate_id', 'N/A')}")
        console.print(f" Status: {result.get('status', 'error')}")
        console.print(f" Initial risk: {result.get('initial_risk_score', 'N/A')}/100")
        console.print(f" Monitoring: {result.get('monitoring_period_days', 0)} days")
        if "error" in result:
            console.print(f"[red]{result['error']}[/]")
            # A failed application should not look like success to the shell.
            raise typer.Exit(1)

    asyncio.run(run())
# ─── Web ──────────────────────────────────────────────────────────────────
@app.command()
def web(
    port: int = typer.Option(8000, "--port", "-p", help="Web UI port"),
    host: str = typer.Option("127.0.0.1", "--host", help="Bind address"),
):
    """Launch the web UI."""
    # uvicorn and the API application are imported at call time, not at
    # module import.
    import uvicorn

    from scanner.api import app as web_app

    address = f"http://{host}:{port}"
    console.print(f"[green]Web UI:[/] {address}")
    # Serves the API application on the requested host and port.
    uvicorn.run(web_app, host=host, port=port)
# ─── Rich Display ─────────────────────────────────────────────────────────
def _display_rich(report, verbose: bool = False):
    """Render a scan report on the shared console as Rich panels and a table.

    Prints a header panel (risk score, URL, finding count, scan time), a
    summary panel when ``report.summary`` is set, and a findings table when
    ``report.findings`` is non-empty.

    Args:
        report: Scan report exposing ``risk_category``, ``risk_score``,
            ``url``, ``total_findings``, ``scan_time_ms``, ``summary`` and
            ``findings`` (each with ``severity``, ``category``, ``title``,
            ``detector`` and ``snippet``).
        verbose: When true, add a "Snippet" column to the findings table.
    """
    from rich.markup import escape

    color_map = {"none": "green", "low": "yellow", "medium": "orange1", "high": "red", "critical": "bold red"}
    color = color_map.get(report.risk_category, "white")

    # Text taken from the scanned target (URL, summary, titles, snippets) is
    # escaped so bracketed content prints literally instead of being parsed
    # as Rich markup, which could raise MarkupError or restyle the output.
    console.print(
        Panel(
            f"[bold]Risk Score:[/] [{color}]{report.risk_score}/100 ({report.risk_category})[/]\n"
            f"[bold]URL:[/] {escape(str(report.url))}\n"
            f"[bold]Findings:[/] {report.total_findings} | "
            f"[bold]Time:[/] {report.scan_time_ms}ms",
            title="Scan Results",
        )
    )

    if report.summary:
        summary = report.summary
        # Only plain strings go through markup parsing; anything else is
        # passed to the panel unchanged.
        if isinstance(summary, str):
            summary = escape(summary)
        console.print(Panel(summary, title="Summary"))

    if report.findings:
        table = Table(title=f"Findings ({len(report.findings)})")
        table.add_column("Severity", style="bold")
        table.add_column("Category", style="cyan")
        table.add_column("Title")
        table.add_column("Detector")
        if verbose:
            table.add_column("Snippet")
        for f in report.findings:
            # Critical/high severities are upper-cased in red; others yellow.
            if f.severity in ("critical", "high"):
                sv = f"[red]{f.severity.upper()}[/]"
            else:
                sv = f"[yellow]{f.severity}[/]"
            row = [sv, f.category, escape(f.title[:60]), f.detector]
            if verbose:
                row.append(escape(f.snippet[:100]))
            table.add_row(*row)
        console.print(table)
def main() -> None:
    """Entry point: hand control to the Typer application."""
    app()


# Run the CLI when this module is executed as a script.
if __name__ == "__main__":
    main()