Spaces:
Sleeping
Sleeping
| """ | |
| models.py: Core data structures shared across the scanner. | |
| A scan flows through three object types: | |
| Probe -> a single adversarial input plus the criteria for deciding | |
| whether the model failed it (defined declaratively in YAML). | |
| Finding -> the result of running one probe against the target when the | |
| model's response indicates a vulnerability (severity-tagged, | |
| with evidence and remediation). | |
| ScanResult -> the aggregate of every probe outcome for one scan run, with | |
| summary statistics used by the reporters and governance docs. | |
| Keeping these decoupled from the probe logic and the I/O layer is what lets the | |
| same finding objects feed the JSON report, the HTML report, the risk register | |
| and the model card without any of those knowing about each other. | |
| """ | |
| from __future__ import annotations | |
| import enum | |
| from dataclasses import dataclass, field, asdict | |
| from datetime import datetime, timezone | |
| from typing import Any, Dict, List, Optional | |
class Severity(enum.Enum):
    """Severity levels for findings, ranked by integer value.

    A higher value means a more severe level. The integer rank drives sorting
    and the CI ``--fail-on`` threshold.
    """

    CRITICAL = 4
    HIGH = 3
    MEDIUM = 2
    LOW = 1
    INFO = 0

    @classmethod
    def from_str(cls, value: str) -> "Severity":
        """Parse a severity name, ignoring case and surrounding whitespace.

        Args:
            value: Member name such as ``"high"`` or ``" CRITICAL "``.

        Returns:
            The matching ``Severity`` member.

        Raises:
            ValueError: If ``value`` is not a string or does not name a member.
        """
        try:
            return cls[value.strip().upper()]
        except (KeyError, AttributeError) as exc:
            # AttributeError covers non-string input (e.g. None or a number),
            # so callers only ever have to handle ValueError.
            valid = ", ".join(s.name for s in cls)
            raise ValueError(
                f"Unknown severity {value!r}. Valid values: {valid}"
            ) from exc

    # A plain Enum is unordered. The comparisons below order members by their
    # integer rank so severities sort, work with `max()`, and can be checked
    # against a threshold with `>=` / `<=`.
    def __lt__(self, other: "Severity") -> bool:
        if not isinstance(other, Severity):
            return NotImplemented
        return self.value < other.value

    def __le__(self, other: "Severity") -> bool:
        if not isinstance(other, Severity):
            return NotImplemented
        return self.value <= other.value

    def __gt__(self, other: "Severity") -> bool:
        if not isinstance(other, Severity):
            return NotImplemented
        return self.value > other.value

    def __ge__(self, other: "Severity") -> bool:
        if not isinstance(other, Severity):
            return NotImplemented
        return self.value >= other.value

    def __str__(self) -> str:  # pragma: no cover - trivial
        return self.name
@dataclass
class Probe:
    """A single declarative test case loaded from a probe pack.

    Attributes:
        id: Stable, unique identifier (e.g. ``pi-001``). Used in reports and
            for suppression/allow-listing.
        category: The test battery this probe belongs to (e.g.
            ``prompt_injection``). Maps 1:1 to a detector.
        name: Short human-readable label.
        severity: Severity assigned to a *failure* of this probe.
        prompt: The adversarial input sent to the model under test.
        detector: Name of the detector function used to judge the response.
        detector_args: Detector-specific parameters (e.g. the canary token a
            leak detector should search for).
        description: What weakness this probe targets.
        remediation: Actionable fix shown on the finding when it triggers.
        owasp: Optional OWASP LLM Top 10 reference (e.g. ``LLM01``).
        context: Optional "retrieved"/tool content for indirect-injection
            probes, kept separate from the user ``prompt`` so the stub and real
            providers can model a realistic RAG/tool boundary.
    """

    id: str
    category: str
    name: str
    severity: Severity
    prompt: str
    detector: str
    detector_args: Dict[str, Any] = field(default_factory=dict)
    description: str = ""
    remediation: str = ""
    owasp: str = ""
    context: Optional[str] = None

    @classmethod
    def from_dict(cls, raw: Dict[str, Any], *, category: str) -> "Probe":
        """Build a probe from one raw mapping of a probe pack.

        Args:
            raw: Mapping with the required keys ``id``, ``name``, ``prompt``
                and ``detector``; all other keys are optional.
            category: Category assigned to the probe (not read from ``raw``).

        Returns:
            A new ``Probe``. ``severity`` defaults to ``MEDIUM`` when absent.

        Raises:
            ValueError: If a required key is missing, or if the severity is
                rejected by ``Severity.from_str``.
        """
        missing = [k for k in ("id", "name", "prompt", "detector") if k not in raw]
        if missing:
            raise ValueError(
                f"Probe in category {category!r} missing required field(s): "
                f"{', '.join(missing)}"
            )
        return cls(
            id=raw["id"],
            category=category,
            name=raw["name"],
            severity=Severity.from_str(raw.get("severity", "MEDIUM")),
            prompt=raw["prompt"],
            detector=raw["detector"],
            # `or {}` tolerates a key that is present but null; dict() copies
            # so the probe never shares the caller's mapping.
            detector_args=dict(raw.get("detector_args") or {}),
            description=raw.get("description", ""),
            remediation=raw.get("remediation", ""),
            owasp=raw.get("owasp", ""),
            context=raw.get("context"),
        )
@dataclass
class Finding:
    """A vulnerability surfaced by a probe whose detector judged the response
    as a failure.

    Carries the identifying fields of the probe (``probe_id``, ``category``,
    ``name``, ``severity``, ``owasp``, ``detector``) together with the
    ``prompt`` and ``response`` of the exchange, the ``evidence`` for the
    failure, and the ``description`` and ``remediation`` text.
    """

    probe_id: str
    category: str
    name: str
    severity: Severity
    description: str
    evidence: str
    remediation: str
    prompt: str
    response: str
    owasp: str = ""
    detector: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the finding as a plain dict with the severity as its name.

        Returns:
            A dict of every field, where ``"severity"`` holds the member name
            (e.g. ``"HIGH"``) instead of the enum object.
        """
        d = asdict(self)
        # Replace the enum member with its name so the dict holds only plain values.
        d["severity"] = self.severity.name
        return d
@dataclass
class ProbeOutcome:
    """Outcome of running a single probe: failed or not.

    Non-failures are retained so the report can show coverage (tests passed
    vs. failed), not just the bad news.

    Attributes:
        probe: The probe that was run.
        response: The response recorded for that probe.
        failed: Whether the probe was judged a failure.
        finding: The associated finding, or ``None`` (the default).
    """

    probe: Probe
    response: str
    failed: bool
    finding: Optional[Finding] = None
@dataclass
class ScanResult:
    """Aggregate result of one full scan run.

    Attributes:
        target: Identifier of the scanned target.
        started_at: Start timestamp, stored as a string.
        finished_at: End timestamp, stored as a string.
        outcomes: One ``ProbeOutcome`` per probe that was run.
        scanner_version: Version string of the scanner (empty by default).
    """

    target: str
    started_at: str
    finished_at: str
    outcomes: List[ProbeOutcome] = field(default_factory=list)
    scanner_version: str = ""

    # ------------------------------------------------------------------ #
    # Derived views
    # ------------------------------------------------------------------ #
    @property
    def findings(self) -> List[Finding]:
        """Findings of all outcomes, most severe first.

        Ties are broken by category and then probe id. The list is rebuilt
        and re-sorted on every access.
        """
        items = [o.finding for o in self.outcomes if o.finding is not None]
        return sorted(items, key=lambda f: (-f.severity.value, f.category, f.probe_id))

    @property
    def total_probes(self) -> int:
        """Number of recorded outcomes."""
        return len(self.outcomes)

    @property
    def total_findings(self) -> int:
        """Number of outcomes that carry a finding."""
        return len(self.findings)

    def severity_counts(self) -> Dict[str, int]:
        """Count of findings per severity, always including every level so the
        report tables are stable."""
        counts = {s.name: 0 for s in Severity}
        for f in self.findings:
            counts[f.severity.name] += 1
        return counts

    def category_counts(self) -> Dict[str, int]:
        """Count of findings per category; categories without findings are omitted."""
        counts: Dict[str, int] = {}
        for f in self.findings:
            counts[f.category] = counts.get(f.category, 0) + 1
        return counts

    @property
    def pass_rate(self) -> float:
        """Fraction of outcomes not marked failed; ``1.0`` when there are none."""
        if not self.outcomes:
            return 1.0
        passed = sum(1 for o in self.outcomes if not o.failed)
        return passed / len(self.outcomes)

    def highest_severity(self) -> Optional[Severity]:
        """Return the most severe level among the findings, or ``None`` if there are none."""
        severities = [o.finding.severity for o in self.outcomes if o.finding is not None]
        if not severities:
            return None
        # Compare by integer rank explicitly instead of relying on enum ordering.
        return max(severities, key=lambda s: s.value)

    def to_dict(self) -> Dict[str, Any]:
        """Return the scan as a plain dict: metadata, summary, findings, passed probes."""
        # Compute the derived views once; each `findings` access re-sorts.
        findings = self.findings
        top = self.highest_severity()
        return {
            "target": self.target,
            "scanner_version": self.scanner_version,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
            "summary": {
                "total_probes": self.total_probes,
                "total_findings": len(findings),
                "pass_rate": round(self.pass_rate, 4),
                "severity_counts": self.severity_counts(),
                "category_counts": self.category_counts(),
                "highest_severity": top.name if top else None,
            },
            "findings": [f.to_dict() for f in findings],
            "passed_probes": [
                {
                    "probe_id": o.probe.id,
                    "category": o.probe.category,
                    "name": o.probe.name,
                }
                for o in self.outcomes
                if not o.failed
            ],
        }
def utcnow_iso() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``.

    The clock is read as a timezone-aware UTC datetime and rendered to whole
    seconds with a literal trailing ``Z``.
    """
    stamp_format = "%Y-%m-%dT%H:%M:%SZ"
    current = datetime.now(tz=timezone.utc)
    return current.strftime(stamp_format)