File size: 5,378 Bytes
80ef9e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b3b677
80ef9e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b3b677
 
 
 
80ef9e0
 
 
5b3b677
 
 
 
 
 
 
 
 
80ef9e0
 
 
5b3b677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80ef9e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""Graph-native path solvability checks."""

from __future__ import annotations

from collections import defaultdict, deque

from open_range.protocols import CheckResult, ContainerSet, SnapshotSpec
from open_range.validator.graphs import CompiledGraphs, compile_snapshot_graphs


class PathSolvabilityCheck:
    """Verify that vuln and flag hosts are reachable in the compiled host graph."""

    async def check(self, snapshot: SnapshotSpec, containers: ContainerSet) -> CheckResult:
        compiled = compile_snapshot_graphs(snapshot)
        issues: list[str] = []

        if not compiled.hosts:
            return CheckResult(
                name="path_solvability",
                passed=False,
                error="snapshot has no compiled hosts",
            )

        start_hosts = _start_hosts(compiled)
        vuln_hosts = {v.host for v in snapshot.truth_graph.vulns if v.host}
        flag_hosts = {flag.host for flag in snapshot.flags if flag.host}
        target_hosts = sorted(vuln_hosts.union(flag_hosts))
        if not target_hosts:
            return CheckResult(
                name="path_solvability",
                passed=False,
                error="snapshot has no vuln or flag hosts to solve toward",
            )

        adjacency = build_host_adjacency(snapshot, compiled)
        unreachable = [
            host
            for host in target_hosts
            if not _reachable_from_any(host, start_hosts, adjacency)
        ]
        if unreachable:
            issues.append(f"unreachable target hosts from start set {sorted(start_hosts)}: {unreachable}")

        for flag_host in sorted(flag_hosts):
            if not (
                flag_host in vuln_hosts
                or _reachable_from_any(flag_host, vuln_hosts or start_hosts, adjacency)
            ):
                issues.append(
                    f"flag host '{flag_host}' is not grounded by any vuln host or start host"
                )

        passed = len(issues) == 0
        return CheckResult(
            name="path_solvability",
            passed=passed,
            details={
                "start_hosts": sorted(start_hosts),
                "target_hosts": target_hosts,
                "issues": issues,
            },
            error="" if passed else "; ".join(issues),
        )


def _start_hosts(compiled: CompiledGraphs) -> set[str]:
    starts = {
        host
        for host in compiled.hosts
        if host in {"attacker", "internet"}
        or compiled.zones_by_host.get(host) == "external"
    }
    if starts:
        return starts
    if compiled.hosts:
        return {sorted(compiled.hosts)[0]}
    return set()


def build_host_adjacency(
    snapshot: SnapshotSpec,
    compiled: CompiledGraphs,
) -> dict[str, set[str]]:
    adjacency: dict[str, set[str]] = defaultdict(set)
    for source, target in compiled.dependency_edges:
        adjacency[source].add(target)

    principal_hosts = _principal_hosts(snapshot)
    for source_principal, target_principal, _edge_type in compiled.trust_edges:
        source_hosts = principal_hosts.get(source_principal, set())
        target_hosts = principal_hosts.get(target_principal, set())
        for source_host in source_hosts:
            for target_host in target_hosts:
                if source_host and target_host:
                    adjacency[source_host].add(target_host)
    return adjacency


def has_host_path(
    start: str,
    target: str,
    adjacency: dict[str, set[str]],
) -> bool:
    return _has_path(start, target, adjacency)


def _principal_hosts(snapshot: SnapshotSpec) -> dict[str, set[str]]:
    topology = snapshot.topology or {}
    mapping: dict[str, set[str]] = defaultdict(set)

    raw_users = topology.get("users", [])
    if isinstance(raw_users, list):
        for raw in raw_users:
            if not isinstance(raw, dict):
                continue
            username = str(raw.get("username", "")).strip()
            if not username:
                continue
            for raw_host in raw.get("hosts", []):
                host = str(raw_host).strip()
                if host:
                    mapping[username].add(host)

    raw_catalog = topology.get("principal_catalog", {})
    if isinstance(raw_catalog, dict):
        for raw_name, raw_principal in raw_catalog.items():
            name = str(raw_name).strip()
            if not name or not isinstance(raw_principal, dict):
                continue
            for raw_host in raw_principal.get("hosts", []):
                host = str(raw_host).strip()
                if host:
                    mapping[name].add(host)

    return mapping


def _reachable_from_any(
    target: str,
    starts: set[str],
    adjacency: dict[str, set[str]],
) -> bool:
    for start in starts:
        if start == target:
            return True
        if _has_path(start, target, adjacency):
            return True
    return False


def _has_path(start: str, target: str, adjacency: dict[str, set[str]]) -> bool:
    queue: deque[str] = deque([start])
    seen = {start}
    while queue:
        current = queue.popleft()
        for neighbor in adjacency.get(current, set()):
            if neighbor == target:
                return True
            if neighbor in seen:
                continue
            seen.add(neighbor)
            queue.append(neighbor)
    return False