Spaces:
Runtime error
Runtime error
"""Zone-based network routing enforcement.

Replaces iptables in the all-in-one container deployment.
All services run on localhost; this module enforces which
zones can reach which other zones on which ports.
The agent experiences an identical training signal to a
multi-container setup with real iptables rules.

All routing data comes from the snapshot/manifest topology.
No hardcoded infrastructure constants.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any

logger = logging.getLogger(__name__)
@dataclass
class ZoneRouter:
    """Enforces network zone routing policy.

    Must be constructed via ``from_snapshot()`` or ``from_manifest()``
    to load topology-driven routes and host-zone mappings. The bare
    constructor creates an empty router: no host has a known zone, so
    every ``check_command_access`` call is denied (fail-closed).
    """

    # (from_zone, to_zone) -> destination ports allowed for cross-zone traffic.
    routes: dict[tuple[str, str], set[int]] = field(default_factory=dict)
    # host name -> zone name; hosts not listed here resolve to "unknown".
    host_zones: dict[str, str] = field(default_factory=dict)

    # ------------------------------------------------------------------ #
    # Constructors
    # ------------------------------------------------------------------ #

    @classmethod
    def from_snapshot(cls, topology: dict[str, Any]) -> "ZoneRouter":
        """Build a router from snapshot topology and firewall rules.

        This is the primary constructor. It reads ``hosts`` and
        ``firewall_rules`` from *topology* to populate ``host_zones``
        and ``routes``.

        * ``hosts`` entries may be dicts (``name`` plus optional ``zone``)
          or bare host-name strings. A missing, ``None`` or empty zone is
          recorded as ``"unknown"``, which ``check_command_access`` denies.
        * Only rules whose ``action`` is ``"allow"`` create routes. Zone
          keys may be spelled ``from_zone``/``to_zone`` or ``from``/``to``.
          Rules for the same zone pair have their ports merged.
        * If ``firewall_rules`` is missing or empty, no cross-zone route
          exists: same-zone traffic is still allowed by ``can_reach`` and
          all cross-zone traffic is denied (deny-by-default).

        Malformed rules (non-dict entries) and unparseable ports are
        skipped with a warning instead of aborting construction. Skipping
        an allow rule can only deny more traffic, never less.
        """
        router = cls()

        # `or []` also covers keys that are present but set to None.
        for host in topology.get("hosts") or []:
            if isinstance(host, dict):
                name = host.get("name", "")
                if name:
                    # `or` maps an explicit None/"" zone to "unknown" so such
                    # hosts stay fail-closed instead of forming a pseudo-zone.
                    router.host_zones[name] = host.get("zone") or "unknown"
            elif isinstance(host, str):
                # String-only entries get zone inferred as "unknown".
                router.host_zones[host] = "unknown"

        for rule in topology.get("firewall_rules") or []:
            if not isinstance(rule, dict):
                logger.warning("Ignoring malformed firewall rule: %r", rule)
                continue
            if rule.get("action", "deny") != "allow":
                continue
            from_z = rule.get("from_zone", rule.get("from", ""))
            to_z = rule.get("to_zone", rule.get("to", ""))
            if from_z and to_z:
                ports = cls._coerce_ports(rule.get("ports"))
                router.routes.setdefault((from_z, to_z), set()).update(ports)

        return router

    @classmethod
    def from_manifest(cls, manifest: dict[str, Any]) -> "ZoneRouter":
        """Build a ZoneRouter from a raw manifest dict.

        Used during validation before a snapshot exists. Uses the
        manifest's ``topology`` section when present and not ``None``.
        Otherwise it treats the manifest itself as the topology. Either
        way it delegates to ``from_snapshot``.
        """
        topology = manifest.get("topology")
        if topology is None:
            topology = manifest
        return cls.from_snapshot(topology)

    @staticmethod
    def _coerce_ports(raw: Any) -> set[int]:
        """Normalize a rule's ``ports`` value to a set of ints.

        Accepts ``None`` (no ports), a single port, or an iterable of
        ports. Each port may be an int or a numeric string. Values that
        ``int()`` rejects are logged and dropped.
        """
        if raw is None:
            return set()
        if isinstance(raw, str):
            # A bare string would otherwise be split into characters.
            items = [raw]
        else:
            try:
                items = list(raw)
            except TypeError:  # scalar such as a single int
                items = [raw]
        ports: set[int] = set()
        for item in items:
            try:
                ports.add(int(item))
            except (TypeError, ValueError):
                logger.warning("Ignoring invalid port in firewall rule: %r", item)
        return ports

    # ------------------------------------------------------------------ #
    # Query methods
    # ------------------------------------------------------------------ #

    def can_reach(self, from_zone: str, to_zone: str, port: int) -> bool:
        """Check if a connection from one zone to another on a port is allowed.

        Traffic within a single zone is always allowed. Cross-zone
        traffic requires a route for the (from_zone, to_zone) pair that
        lists *port*.
        """
        if from_zone == to_zone:
            return True  # same zone always allowed
        return port in self.routes.get((from_zone, to_zone), set())

    def get_zone(self, host: str) -> str:
        """Return the zone recorded for *host*, or ``"unknown"`` if unmapped."""
        return self.host_zones.get(host, "unknown")

    def check_command_access(
        self, from_host: str, target_host: str, port: int = 0
    ) -> tuple[bool, str]:
        """Check if from_host can access target_host on port.

        Returns ``(allowed, reason)``. A host whose zone is ``"unknown"``
        (unmapped, or mapped without a zone) is denied (fail-closed). This
        holds even when both hosts are unknown, which ``can_reach`` alone
        would treat as same-zone.
        """
        from_zone = self.get_zone(from_host)
        to_zone = self.get_zone(target_host)

        if from_zone == "unknown" or to_zone == "unknown":
            # Report the offending host: the zone itself is always "unknown",
            # so printing it would tell the caller nothing.
            unknown_host = from_host if from_zone == "unknown" else target_host
            return False, f"unknown zone: host {unknown_host}"

        if self.can_reach(from_zone, to_zone, port):
            logger.debug("ALLOW %s(%s) -> %s(%s):%d", from_host, from_zone, target_host, to_zone, port)
            return True, "allowed"

        logger.info("BLOCK %s(%s) -> %s(%s):%d", from_host, from_zone, target_host, to_zone, port)
        return False, f"Zone {from_zone} cannot reach {to_zone} on port {port}"