Demo / server /threat_graph.py
Ajayyy00
Initial commit of CyberSOC upgraded RLVR environment
57e71f8
"""ThreatGraph — typed knowledge graph of SOC entities, edges, and evidence."""
from __future__ import annotations
from datetime import datetime
from typing import Literal, Optional
from pydantic import BaseModel, Field
class HostNode(BaseModel):
    """Graph node describing one host.

    ``ThreatGraph.add_host`` stores instances keyed by ``hostname``.
    """

    # Identity / network placement.
    hostname: str
    subnet: str

    # Closed vocabularies; values outside the Literal sets are rejected
    # by pydantic validation.
    business_criticality: Literal[
        "low", "medium", "high", "critical"
    ]
    status: Literal[
        "healthy", "suspicious", "compromised", "isolated", "contained"
    ]

    # Optional bookkeeping: no timestamp and not scanned unless supplied.
    first_seen_suspicious: Optional[datetime] = Field(default=None)
    scanned: bool = Field(default=False)
class ProcessNode(BaseModel):
    """Graph node describing one process on a host.

    ``ThreatGraph.add_process`` stores instances keyed by ``process_id``.
    """

    # Unique key, format: "hostname:pid"
    process_id: str
    hostname: str
    process_name: str

    # Defaults to False when not supplied.
    killed: bool = Field(default=False)
class IOCNode(BaseModel):
    """Graph node describing one indicator of compromise (IOC).

    ``ThreatGraph.add_ioc`` stores instances keyed by ``ioc_value``.
    """

    # The indicator itself and which kind of indicator it is.
    ioc_value: str
    ioc_type: Literal[
        "ip", "domain", "hash", "filename"
    ]

    # NOTE(review): presumably a 0..1 score; no range is enforced here.
    confidence: float

    # Response / enrichment flags, both off by default.
    blocked: bool = Field(default=False)
    enriched: bool = Field(default=False)

    # Attribution data; empty unless supplied.
    threat_actor: Optional[str] = Field(default=None)
    mitre_ttps: list[str] = Field(default_factory=list)
class VulnerabilityNode(BaseModel):
    """Graph node describing one CVE as it applies to one host.

    ``ThreatGraph.add_vulnerability`` stores instances under the composite
    key ``"{hostname}:{cve_id}"``.
    """

    # The two fields that together form the graph key.
    cve_id: str
    hostname: str

    # NOTE(review): presumably on the CVSS 0-10 scale; no range is enforced.
    cvss_score: float
    exploitability: Literal[
        "active", "theoretical", "patched"
    ]
    patch_available: bool

    # Unset unless supplied.
    exploited_by_threat: Optional[str] = Field(default=None)
class AlertNode(BaseModel):
    """Graph node describing one alert.

    ``ThreatGraph.add_alert`` stores instances keyed by ``alert_id``.
    """

    alert_id: str
    severity: Literal[
        "low", "medium", "high", "critical"
    ]
    priority_score: float

    # Hostname string of the originating host.
    source_host: str

    # NOTE(review): presumably ids of related alerts; verify against callers.
    # Each instance gets its own fresh empty list by default.
    correlated_with: list[str] = Field(default_factory=list)
class Edge(BaseModel):
    """Typed, directed relationship between two graph entities.

    ``source_id`` and ``target_id`` are plain string ids. They are what
    ``ThreatGraph.compute_evidence_confidence`` compares against a threat id
    and looks up in the node dictionaries.
    """

    # Relationship kind, limited to the values below.
    edge_type: Literal[
        "runs_on",
        "involves",
        "communicates_with",
        "pivoted_from",
        "part_of_chain",
        "exploits",
    ]

    # Endpoints of the relationship.
    source_id: str
    target_id: str

    # Free-form supporting data; a fresh empty dict per instance by default.
    evidence: dict = Field(default_factory=dict)
# Soft cap on the total node count (all node types combined). Only
# ``ThreatGraph.add_ioc`` enforces it, by evicting the oldest IOC nodes.
MAX_GRAPH_NODES = 200


class ThreatGraph:
    """In-memory graph of SOC nodes and edges with a version counter.

    Nodes live in one dictionary per node type; edges live in a flat list.
    Every successful ``add_*`` call increments ``version`` by one and appends
    a ``(version, entity_type, entity_id)`` entry to the changelog, which
    ``delta_since`` uses to summarise what was added after a given version.
    """

    def __init__(self):
        """Create an empty graph at version 0."""
        self.hosts: dict[str, HostNode] = {}
        self.processes: dict[str, ProcessNode] = {}
        self.iocs: dict[str, IOCNode] = {}
        self.vulnerabilities: dict[str, VulnerabilityNode] = {}
        self.alerts: dict[str, AlertNode] = {}
        self.edges: list[Edge] = []
        self.version: int = 0
        # changelog entries: (version_after_add, entity_type, entity_id)
        self._changelog: list[tuple[int, str, str]] = []
        # insertion-order tracking for IOC pruning (oldest first)
        self._ioc_insertion_order: list[str] = []

    def _total_nodes(self) -> int:
        """Return the number of nodes of all types (edges are not counted)."""
        stores = (
            self.hosts,
            self.processes,
            self.iocs,
            self.vulnerabilities,
            self.alerts,
        )
        return sum(len(store) for store in stores)

    def _record(self, entity_type: str, entity_id: str) -> None:
        """Bump the graph version and log the addition in the changelog."""
        self.version += 1
        self._changelog.append((self.version, entity_type, entity_id))

    def _prune_oldest_iocs(self, needed: int = 1) -> None:
        """Evict up to ``needed`` of the oldest IOC nodes.

        Walks the insertion-order list from the front. Entries whose IOC is no
        longer present in ``self.iocs`` are discarded without counting toward
        ``needed``. Stops early when the list is exhausted, so fewer than
        ``needed`` IOCs may be removed. Edges and changelog entries that
        mention an evicted IOC are left untouched.
        """
        order = self._ioc_insertion_order
        evicted = 0
        consumed = 0
        for ioc_value in order:
            if evicted >= needed:
                break
            consumed += 1
            if ioc_value in self.iocs:
                del self.iocs[ioc_value]
                evicted += 1
        # Drop the consumed prefix with a single slice delete rather than
        # calling pop(0) once per entry.
        del order[:consumed]

    def add_host(self, node: HostNode) -> None:
        """Insert or overwrite a host, keyed by ``node.hostname``."""
        self.hosts[node.hostname] = node
        self._record("host", node.hostname)

    def add_process(self, node: ProcessNode) -> None:
        """Insert or overwrite a process, keyed by ``node.process_id``."""
        self.processes[node.process_id] = node
        self._record("process", node.process_id)

    def add_ioc(self, node: IOCNode) -> None:
        """Insert a new IOC, evicting the oldest IOCs if the graph is full.

        An IOC whose ``ioc_value`` is already present is ignored entirely:
        the stored node is not replaced and the version is not bumped.

        Otherwise, enough of the oldest IOCs are evicted that the total node
        count is at most ``MAX_GRAPH_NODES`` after the insert. The other
        ``add_*`` methods do not enforce the cap, so the graph may be more
        than one node over it by the time this runs; evicting a single IOC
        would then leave it permanently over the cap. If there are not enough
        IOCs to evict, the new IOC is still inserted.
        """
        if node.ioc_value in self.iocs:
            return  # already present — no cap action needed
        # Number of slots to free so that total + 1 <= MAX_GRAPH_NODES.
        overflow = self._total_nodes() - MAX_GRAPH_NODES + 1
        if overflow > 0:
            self._prune_oldest_iocs(needed=overflow)
        self.iocs[node.ioc_value] = node
        self._ioc_insertion_order.append(node.ioc_value)
        self._record("ioc", node.ioc_value)

    def add_vulnerability(self, node: VulnerabilityNode) -> None:
        """Insert or overwrite a vulnerability, keyed by ``"hostname:cve_id"``."""
        key = f"{node.hostname}:{node.cve_id}"
        self.vulnerabilities[key] = node
        self._record("vulnerability", key)

    def add_alert(self, node: AlertNode) -> None:
        """Insert or overwrite an alert, keyed by ``node.alert_id``."""
        self.alerts[node.alert_id] = node
        self._record("alert", node.alert_id)

    def add_edge(self, edge: Edge) -> None:
        """Append an edge; duplicates are not detected.

        The changelog id has the form ``"edge_type:source_id->target_id"``.
        """
        self.edges.append(edge)
        self._record(
            "edge", f"{edge.edge_type}:{edge.source_id}->{edge.target_id}"
        )

    def delta_since(self, version: int) -> dict:
        """Return a compact summary of nodes/edges added since ``version``.

        Args:
            version: Only changelog entries with a strictly greater version
                are included. Zero or a negative value selects everything.

        Returns:
            A dict with keys ``from_version`` (the argument), ``to_version``
            (the current graph version), ``counts`` (additions per entity
            type), ``ids`` (at most the first five ids per entity type) and
            ``summary`` (a one-line text rendering of ``counts``, with entity
            types sorted alphabetically).
        """
        if version <= 0:
            entries = list(self._changelog)
        else:
            entries = [
                entry for entry in self._changelog if entry[0] > version
            ]

        # Group ids by entity type; dict order follows first appearance.
        ids_by_type: dict[str, list[str]] = {}
        for _, entity_type, entity_id in entries:
            ids_by_type.setdefault(entity_type, []).append(entity_id)

        counts = {etype: len(ids) for etype, ids in ids_by_type.items()}
        # Truncate each id list to keep summary compact
        compact_ids = {etype: ids[:5] for etype, ids in ids_by_type.items()}

        prefix = f"Δ since v{version}: "
        if counts:
            summary_text = prefix + ", ".join(
                f"{etype}={counts[etype]}" for etype in sorted(counts)
            )
        else:
            summary_text = prefix + "(no changes)"

        return {
            "from_version": version,
            "to_version": self.version,
            "counts": counts,
            "ids": compact_ids,
            "summary": summary_text,
        }

    def compute_evidence_confidence(
        self, threat_id: str, rubric_item_count: int = 3
    ) -> float:
        """Confidence that a threat is well-evidenced.

        Collects the distinct ids connected to ``threat_id`` by any edge, in
        either direction, and counts those that are known hosts, processes,
        IOCs or vulnerabilities. Ids that are alerts, or that match no stored
        node, do not count.

        Denominator is normalized to task complexity:
        ``max(3, rubric_item_count * 1.5)``. This prevents reward hacking via
        forensics spam — the score is measured against the rubric-sized
        baseline, not against however many nodes were generated.

        Args:
            threat_id: Id to look for in edge endpoints.
            rubric_item_count: Number of rubric items for the task.

        Returns:
            A value clamped to ``[0.0, 1.0]``; ``0.0`` when no edge touches
            ``threat_id``.
        """
        linked_ids: set[str] = set()
        for edge in self.edges:
            if edge.source_id == threat_id:
                linked_ids.add(edge.target_id)
            elif edge.target_id == threat_id:
                linked_ids.add(edge.source_id)
        if not linked_ids:
            return 0.0

        evidence_stores = (
            self.hosts,
            self.processes,
            self.iocs,
            self.vulnerabilities,
        )
        non_alert_count = sum(
            1
            for nid in linked_ids
            if nid not in self.alerts
            and any(nid in store for store in evidence_stores)
        )

        denominator = max(3.0, rubric_item_count * 1.5)
        return max(0.0, min(1.0, non_alert_count / denominator))

    def get_context_summary(self) -> str:
        """Compact LLM-injectable summary of current graph state.

        Reports the totals per node type and edge count, plus the number of
        compromised hosts, critical alerts, and blocked / enriched IOCs.
        """
        compromised = sum(
            1 for host in self.hosts.values() if host.status == "compromised"
        )
        critical_alerts = sum(
            1 for alert in self.alerts.values() if alert.severity == "critical"
        )
        blocked = sum(1 for ioc in self.iocs.values() if ioc.blocked)
        enriched = sum(1 for ioc in self.iocs.values() if ioc.enriched)
        return (
            f"Hosts: {len(self.hosts)} ({compromised} compromised) "
            f"Alerts: {len(self.alerts)} ({critical_alerts} critical) "
            f"IOCs: {len(self.iocs)} ({blocked} blocked, {enriched} enriched) "
            f"Vulns: {len(self.vulnerabilities)} "
            f"Edges: {len(self.edges)}"
        )