| """ThreatGraph — typed knowledge graph of SOC entities, edges, and evidence.""" |
|
|
| from __future__ import annotations |
|
|
| from datetime import datetime |
| from typing import Literal, Optional |
|
|
| from pydantic import BaseModel, Field |
|
|
|
|
class HostNode(BaseModel):
    """A host entity in the threat graph.

    ThreatGraph keeps these in its ``hosts`` mapping, keyed by ``hostname``.
    """

    hostname: str
    # Free-form string; no format is validated here (presumably a network
    # range or segment label -- confirm against producers).
    subnet: str
    business_criticality: Literal["low", "medium", "high", "critical"]
    status: Literal[
        "healthy",
        "suspicious",
        "compromised",
        "isolated",
        "contained",
    ]
    # Stays None unless a caller supplies a timestamp.
    first_seen_suspicious: Optional[datetime] = None
    scanned: bool = False
|
|
|
|
class ProcessNode(BaseModel):
    """A process entity in the threat graph.

    ThreatGraph keeps these in its ``processes`` mapping, keyed by
    ``process_id``.
    """

    process_id: str
    # Presumably the host the process was observed on -- nothing here checks
    # that a matching HostNode exists.
    hostname: str
    process_name: str
    killed: bool = False
|
|
|
|
class IOCNode(BaseModel):
    """An indicator of compromise (IOC) in the threat graph.

    ThreatGraph keeps these in its ``iocs`` mapping, keyed by ``ioc_value``.
    """

    ioc_value: str
    ioc_type: Literal["ip", "domain", "hash", "filename"]
    # No range is enforced by this model (presumably 0.0-1.0 -- confirm).
    confidence: float
    blocked: bool = False
    enriched: bool = False
    threat_actor: Optional[str] = None
    # default_factory gives every instance its own fresh list.
    mitre_ttps: list[str] = Field(default_factory=list)
|
|
|
|
class VulnerabilityNode(BaseModel):
    """A vulnerability finding for one CVE on one host.

    ThreatGraph keeps these in its ``vulnerabilities`` mapping under the
    composite key ``"<hostname>:<cve_id>"``.
    """

    cve_id: str
    hostname: str
    # No range is enforced by this model (CVSS is normally 0.0-10.0 --
    # confirm that producers follow that scale).
    cvss_score: float
    exploitability: Literal["active", "theoretical", "patched"]
    patch_available: bool
    # Stays None unless a caller supplies an identifier.
    exploited_by_threat: Optional[str] = None
|
|
|
|
class AlertNode(BaseModel):
    """An alert entity in the threat graph.

    ThreatGraph keeps these in its ``alerts`` mapping, keyed by ``alert_id``.
    """

    alert_id: str
    severity: Literal["low", "medium", "high", "critical"]
    priority_score: float
    source_host: str
    # Presumably identifiers of related alerts -- confirm against callers.
    # default_factory gives every instance its own fresh list.
    correlated_with: list[str] = Field(default_factory=list)
|
|
|
|
class Edge(BaseModel):
    """A typed relationship between two node identifiers.

    ThreatGraph appends these to its ``edges`` list and logs each one as
    ``"<edge_type>:<source_id>-><target_id>"``.
    """

    edge_type: Literal[
        "runs_on",
        "involves",
        "communicates_with",
        "pivoted_from",
        "part_of_chain",
        "exploits",
    ]
    source_id: str
    target_id: str
    # Unstructured payload; default_factory gives every edge its own dict.
    evidence: dict = Field(default_factory=dict)
|
|
|
|
# Node-count threshold across all node types. Only ThreatGraph.add_ioc checks
# it: once the total has reached this value, the oldest IOC is evicted before
# a new one is stored.
MAX_GRAPH_NODES: int = 200
|
|
|
|
| class ThreatGraph: |
| def __init__(self): |
| self.hosts: dict[str, HostNode] = {} |
| self.processes: dict[str, ProcessNode] = {} |
| self.iocs: dict[str, IOCNode] = {} |
| self.vulnerabilities: dict[str, VulnerabilityNode] = {} |
| self.alerts: dict[str, AlertNode] = {} |
| self.edges: list[Edge] = [] |
| self.version: int = 0 |
| |
| self._changelog: list[tuple[int, str, str]] = [] |
| |
| self._ioc_insertion_order: list[str] = [] |
|
|
| def _total_nodes(self) -> int: |
| return ( |
| len(self.hosts) + len(self.processes) + len(self.iocs) |
| + len(self.vulnerabilities) + len(self.alerts) |
| ) |
|
|
| def _prune_oldest_iocs(self, needed: int = 1) -> None: |
| """Remove the oldest `needed` IOCNodes to stay under MAX_GRAPH_NODES.""" |
| pruned = 0 |
| while pruned < needed and self._ioc_insertion_order: |
| oldest = self._ioc_insertion_order.pop(0) |
| if oldest in self.iocs: |
| del self.iocs[oldest] |
| pruned += 1 |
|
|
| def add_host(self, node: HostNode) -> None: |
| self.hosts[node.hostname] = node |
| self.version += 1 |
| self._changelog.append((self.version, "host", node.hostname)) |
|
|
| def add_process(self, node: ProcessNode) -> None: |
| self.processes[node.process_id] = node |
| self.version += 1 |
| self._changelog.append((self.version, "process", node.process_id)) |
|
|
| def add_ioc(self, node: IOCNode) -> None: |
| if node.ioc_value in self.iocs: |
| return |
| if self._total_nodes() >= MAX_GRAPH_NODES: |
| self._prune_oldest_iocs(needed=1) |
| self.iocs[node.ioc_value] = node |
| self._ioc_insertion_order.append(node.ioc_value) |
| self.version += 1 |
| self._changelog.append((self.version, "ioc", node.ioc_value)) |
|
|
| def add_vulnerability(self, node: VulnerabilityNode) -> None: |
| key = f"{node.hostname}:{node.cve_id}" |
| self.vulnerabilities[key] = node |
| self.version += 1 |
| self._changelog.append((self.version, "vulnerability", key)) |
|
|
| def add_alert(self, node: AlertNode) -> None: |
| self.alerts[node.alert_id] = node |
| self.version += 1 |
| self._changelog.append((self.version, "alert", node.alert_id)) |
|
|
| def add_edge(self, edge: Edge) -> None: |
| self.edges.append(edge) |
| self.version += 1 |
| edge_id = f"{edge.edge_type}:{edge.source_id}->{edge.target_id}" |
| self._changelog.append((self.version, "edge", edge_id)) |
|
|
| def delta_since(self, version: int) -> dict: |
| """Return compact summary of nodes/edges added since `version`.""" |
| if version <= 0: |
| entries = list(self._changelog) |
| else: |
| entries = [e for e in self._changelog if e[0] > version] |
|
|
| counts: dict[str, int] = {} |
| ids_by_type: dict[str, list[str]] = {} |
| for _, etype, eid in entries: |
| counts[etype] = counts.get(etype, 0) + 1 |
| ids_by_type.setdefault(etype, []).append(eid) |
|
|
| |
| compact_ids = {k: v[:5] for k, v in ids_by_type.items()} |
|
|
| summary_parts = [f"{t}={counts[t]}" for t in sorted(counts.keys())] |
| summary_text = f"Δ since v{version}: " + ", ".join(summary_parts) if summary_parts else f"Δ since v{version}: (no changes)" |
|
|
| return { |
| "from_version": version, |
| "to_version": self.version, |
| "counts": counts, |
| "ids": compact_ids, |
| "summary": summary_text, |
| } |
|
|
| def compute_evidence_confidence( |
| self, threat_id: str, rubric_item_count: int = 3 |
| ) -> float: |
| """Confidence that a threat is well-evidenced. |
| |
| Denominator is normalized to task complexity: max(3, rubric_item_count * 1.5). |
| This prevents reward hacking via forensics spam — a spammer who generates |
| 10 graph nodes only scores against the rubric-sized baseline, not 10. |
| """ |
| linked_ids: set[str] = set() |
| for edge in self.edges: |
| if edge.source_id == threat_id: |
| linked_ids.add(edge.target_id) |
| elif edge.target_id == threat_id: |
| linked_ids.add(edge.source_id) |
|
|
| if not linked_ids: |
| return 0.0 |
|
|
| non_alert_count = 0 |
| for nid in linked_ids: |
| if nid in self.alerts: |
| continue |
| if ( |
| nid in self.hosts |
| or nid in self.processes |
| or nid in self.iocs |
| or nid in self.vulnerabilities |
| ): |
| non_alert_count += 1 |
|
|
| denominator = max(3.0, rubric_item_count * 1.5) |
| confidence = non_alert_count / denominator |
| return max(0.0, min(1.0, confidence)) |
|
|
| def get_context_summary(self) -> str: |
| """Compact LLM-injectable summary of current graph state.""" |
| compromised = sum(1 for h in self.hosts.values() if h.status == "compromised") |
| critical_alerts = sum(1 for a in self.alerts.values() if a.severity == "critical") |
| blocked = sum(1 for i in self.iocs.values() if i.blocked) |
| enriched = sum(1 for i in self.iocs.values() if i.enriched) |
| return ( |
| f"Hosts: {len(self.hosts)} ({compromised} compromised) " |
| f"Alerts: {len(self.alerts)} ({critical_alerts} critical) " |
| f"IOCs: {len(self.iocs)} ({blocked} blocked, {enriched} enriched) " |
| f"Vulns: {len(self.vulnerabilities)} " |
| f"Edges: {len(self.edges)}" |
| ) |
|
|