Spaces:
Runtime error
Runtime error
| """Render SnapshotSpec into Docker artifacts via Jinja2 templates. | |
| Takes a validated SnapshotSpec and produces the concrete files needed | |
| to boot a range: docker-compose.yml, Dockerfiles, nginx.conf, init.sql, | |
| iptables.rules, and any generated service/app payload files. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import shlex | |
| from pathlib import Path | |
| from pathlib import PurePosixPath | |
| from typing import Any | |
| import jinja2 | |
| from open_range.builder.manifest_graph import runtime_contract_from_topology | |
| from open_range.builder.service_manifest import generate_service_specs | |
| from open_range.protocols import SnapshotSpec | |
logger = logging.getLogger(__name__)

# Jinja2 templates are shipped in a ``templates/`` directory next to this module.
_TEMPLATE_DIR = Path(__file__).parent / "templates"

# Template filename -> rendered output filename.  Each template is named after
# the artifact it produces plus a ``.j2`` suffix, so the output name is derived
# by stripping that suffix (insertion order = render order).
_TEMPLATE_MAP: dict[str, str] = {
    template_name: template_name[: -len(".j2")]
    for template_name in (
        "docker-compose.yml.j2",
        "Dockerfile.web.j2",
        "Dockerfile.db.j2",
        "nginx.conf.j2",
        "init.sql.j2",
        "iptables.rules.j2",
    )
}

# Directory (relative to the render output dir) under which every
# ``SnapshotSpec.files`` payload is written.
PAYLOAD_ROOT_DIR = "rendered_files"
# JSON manifest (in the render output dir) mapping each payload key to the
# relative path it was written to.
PAYLOAD_MANIFEST_NAME = "file-payloads.json"
class SnapshotRenderer:
    """Render Jinja2 templates from a SnapshotSpec to an output directory.

    Uses the templates in the ``templates/`` directory adjacent to this module
    (or a caller-supplied directory) to produce every artifact listed in
    ``_TEMPLATE_MAP``, writes each ``SnapshotSpec.files`` payload under
    ``PAYLOAD_ROOT_DIR``, and records where the payloads were written in a
    JSON manifest named ``PAYLOAD_MANIFEST_NAME``.
    """

    def __init__(self, template_dir: Path | None = None) -> None:
        """Initialize with an optional custom template directory.

        *template_dir* falls back to the bundled ``templates/`` directory when
        it is ``None`` (or otherwise falsy).
        """
        self.template_dir = template_dir or _TEMPLATE_DIR
        self.env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(str(self.template_dir)),
            keep_trailing_newline=True,
            # Lenient undefined handling: templates test optional keys with
            # ``is defined`` (see _build_context).
            undefined=jinja2.Undefined,
        )
        # Lets templates write ``{{ value | shell_quote }}`` for shell-safe output.
        self.env.filters["shell_quote"] = shlex.quote

    def render(self, spec: SnapshotSpec, output_dir: Path) -> Path:
        """Render all templates and write artifacts to *output_dir*.

        Creates *output_dir* if needed, renders every template in
        ``_TEMPLATE_MAP``, writes the file payloads plus their JSON manifest,
        and finally populates ``spec.services`` when it is empty (this
        mutates *spec*).

        Returns the output directory path.

        Raises:
            ValueError: If a ``spec.files`` key cannot be mapped to a safe
                bundle path (see :func:`payload_relpath`).
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        context = _build_context(spec)

        logger.info(
            "SnapshotRenderer: rendering %d templates to %s",
            len(_TEMPLATE_MAP),
            output_dir,
        )

        for template_name, output_name in _TEMPLATE_MAP.items():
            template = self.env.get_template(template_name)
            rendered = template.render(**context)
            dest = output_dir / output_name
            dest.write_text(rendered, encoding="utf-8")
            logger.info("Rendered %s -> %s", template_name, dest)

        payload_manifest = self._render_payloads(spec, output_dir)
        manifest_path = output_dir / PAYLOAD_MANIFEST_NAME
        manifest_path.write_text(
            json.dumps(payload_manifest, indent=2, sort_keys=True),
            encoding="utf-8",
        )
        logger.info("Rendered %d payload artifact(s) -> %s", len(payload_manifest), manifest_path)

        # Generate ServiceSpec entries from compose + topology
        self._build_service_specs(spec)

        logger.info(
            "SnapshotRenderer: rendering complete (%d templates, %d payloads, %d services)",
            len(_TEMPLATE_MAP),
            len(payload_manifest),
            len(spec.services),
        )
        return output_dir

    def _build_service_specs(self, spec: SnapshotSpec) -> None:
        """Populate ``spec.services`` from compose and topology.

        Delegates to :func:`generate_service_specs`, passing ``spec.compose``
        and ``spec.topology``, and assigns the result to ``spec.services``.
        Does nothing when ``spec.services`` is already non-empty, so repeated
        calls do not regenerate the entries.
        """
        if spec.services:
            logger.debug("ServiceSpec entries already present — skipping generation")
            return

        svc_specs = generate_service_specs(
            compose=spec.compose,
            topology=spec.topology,
        )
        spec.services = svc_specs

        if svc_specs:
            logger.info(
                "Generated %d ServiceSpec entries: %s",
                len(svc_specs),
                [s.daemon for s in svc_specs],
            )

    def _render_payloads(self, spec: SnapshotSpec, output_dir: Path) -> dict[str, str]:
        """Write every ``spec.files`` entry beneath *output_dir*.

        Returns a mapping of each ``spec.files`` key to the path it was
        written to, relative to *output_dir*.  Paths always use forward
        slashes so the manifest is identical regardless of the host OS.

        Raises:
            ValueError: If a key cannot be mapped by :func:`payload_relpath`.
        """
        payload_manifest: dict[str, str] = {}
        written_by: dict[Path, str] = {}
        for key, content in spec.files.items():
            rel_path = payload_relpath(key)
            previous_key = written_by.get(rel_path)
            if previous_key is not None:
                # Distinct keys can sanitize to the same path (payload_relpath
                # drops "." / ".." segments); the later payload wins.
                logger.warning(
                    "Payload %s overwrites payload %s at %s",
                    key,
                    previous_key,
                    rel_path.as_posix(),
                )
            written_by[rel_path] = key

            dest = output_dir / rel_path
            dest.parent.mkdir(parents=True, exist_ok=True)
            dest.write_text(content, encoding="utf-8")
            payload_manifest[key] = rel_path.as_posix()
            logger.info("Rendered payload %s -> %s", key, dest)
        return payload_manifest
def payload_relpath(file_key: str) -> Path:
    """Return the artifact-bundle relative path for a SnapshotSpec.files entry.

    Keys have the form ``"<container>:<path>"``.  The special key
    ``"db:sql"`` maps to a fixed ``db/sql/generated.sql`` location.  For all
    other keys the path part is sanitized by dropping root, empty, ``.`` and
    ``..`` segments.  The container must be a single plain path segment.
    Together these keep the result under ``PAYLOAD_ROOT_DIR``.

    Raises:
        ValueError: If the key has no ``:``, the container name is empty,
            ``.``/``..``, or contains a path separator, or the path has no
            usable segments after sanitization.
    """
    if file_key == "db:sql":
        return Path(PAYLOAD_ROOT_DIR) / "db" / "sql" / "generated.sql"
    if ":" not in file_key:
        raise ValueError(f"Invalid file payload key: {file_key}")

    container, raw_path = file_key.split(":", 1)
    # The container becomes a directory directly under PAYLOAD_ROOT_DIR.
    # Reject anything that is not one plain segment.  An absolute name
    # ("/tmp") makes ``Path / container`` discard the root entirely.
    # "..", "." or embedded separators would escape or nest the bundle root.
    if container in {"", ".", ".."} or "/" in container or "\\" in container:
        raise ValueError(f"Invalid file payload container: {file_key}")

    safe_parts = [
        part
        for part in PurePosixPath(raw_path).parts
        if part not in {"", "/", ".", ".."}
    ]
    if not safe_parts:
        raise ValueError(f"Invalid file payload path: {file_key}")
    return Path(PAYLOAD_ROOT_DIR) / container / Path(*safe_parts)
def _wants_endpoint(
    injection_points: set[Any],
    vuln_types: set[Any],
    needles: tuple[str, ...],
    vuln_type: str,
) -> bool:
    """Decide whether an optional nginx endpoint block should be rendered.

    True when any injection point contains one of *needles* as a substring,
    or when *vuln_type* is one of *vuln_types*.  Injection points are checked
    first.
    """
    # NOTE(review): the substring tests assume every injection point is a str;
    # a None value would raise TypeError here — confirm against the protocol.
    mentioned = any(
        any(needle in point for needle in needles) for point in injection_points
    )
    return mentioned or vuln_type in vuln_types


def _build_context(spec: SnapshotSpec) -> dict[str, Any]:
    """Build the Jinja2 template context from a SnapshotSpec.

    Flattens topology, runtime-contract, flag and vulnerability data into the
    variable names the templates expect.  The optional ``search_endpoint`` /
    ``download_endpoint`` keys are present (and True) only when enabled,
    because the templates test them with ``is defined``.
    """
    topology = spec.topology
    hosts_raw = topology.get("hosts", [])
    zones = topology.get("zones", {})
    users = topology.get("users", [])
    contract = runtime_contract_from_topology(topology)

    hosts = _build_hosts(hosts_raw, zones)
    networks = _build_networks(zones)
    zone_cidrs = _build_zone_cidrs(zones)

    # Vulnerability metadata decides which nginx endpoint blocks are enabled.
    vulns = spec.truth_graph.vulns
    vuln_types = {vuln.type for vuln in vulns}
    injection_points = {vuln.injection_point for vuln in vulns}
    search_on = _wants_endpoint(injection_points, vuln_types, ("search", "q="), "sqli")
    download_on = _wants_endpoint(
        injection_points, vuln_types, ("download", "file="), "path_traversal"
    )

    logger.debug(
        "_build_context: %d hosts, %d networks, search=%s, download=%s",
        len(hosts),
        len(networks),
        search_on,
        download_on,
    )

    db_user = contract["db_user"]
    db_pass = contract["db_password"]

    context: dict[str, Any] = dict(
        # docker-compose.yml.j2
        snapshot_id=topology.get("snapshot_id", "generated"),
        networks=networks,
        hosts=hosts,
        host_names=[host["name"] for host in hosts],
        db_host=contract["db_host"],
        db_user=db_user,
        db_pass=db_pass,
        db_name=contract["db_name"],
        # Same value as db_pass: Dockerfile.db.j2 reads db_pass while
        # docker-compose.yml.j2 reads db_password.
        db_password=db_pass,
        mysql_root_password=topology.get("mysql_root_password", _find_mysql_root_pass(users)),
        domain=contract["domain"],
        org_name=topology.get("org_name", "Corp"),
        ldap_admin_pass=topology.get("ldap_admin_pass", "LdapAdm1n!"),
        smb_shares=_find_smb_shares(spec),
        smb_user=_find_smb_user(users),
        smb_password=_find_smb_pass(users),
        web_doc_root=contract["web_doc_root"],
        web_config_path=contract["web_config_path"],
        ldap_bind_dn=contract["ldap_bind_dn"],
        ldap_bind_pw=contract["ldap_bind_pw"],
        ldap_search_base_dn=contract["ldap_search_base_dn"],
        # Dockerfile.web.j2
        users=users,
        app_files=spec.files or {},
        flags=[flag.model_dump() for flag in spec.flags],
        # nginx.conf.j2
        server_name=topology.get("domain", f"{contract['web_host']}.{contract['domain']}"),
        # iptables.rules.j2
        firewall_rules=topology.get("firewall_rules", []),
        zone_cidrs=zone_cidrs,
    )

    # Omitted keys stay undefined, which keeps the corresponding blocks out.
    toggles = {"search_endpoint": search_on, "download_endpoint": download_on}
    context.update({key: True for key, enabled in toggles.items() if enabled})
    return context
def _build_hosts(
    hosts_raw: list[str] | list[dict[str, Any]],
    zones: dict[str, list[str]],
) -> list[dict[str, Any]]:
    """Convert host list (strings or dicts) into template-ready dicts.

    Each entry is either a bare host name or a dict with ``name`` and
    optional ``zone``, ``networks`` and ``depends_on`` keys.  A missing
    zone is looked up in *zones* (if a host is listed in several zones, the
    last one wins), falling back to ``"default"``.  ``networks`` defaults to
    ``[zone]`` and ``depends_on`` to ``[]``.
    """
    zone_of = {host: zone for zone, members in zones.items() for host in members}

    def make_entry(name: Any, zone: Any, networks: Any, depends_on: Any) -> dict[str, Any]:
        return {
            "name": name,
            "zone": zone,
            "networks": networks,
            "depends_on": depends_on,
        }

    built: list[dict[str, Any]] = []
    for raw in hosts_raw:
        if not isinstance(raw, dict):
            # Bare host name: everything derives from the zone lookup.
            fallback_zone = zone_of.get(raw, "default")
            built.append(make_entry(raw, fallback_zone, [fallback_zone], []))
            continue
        name = raw["name"]
        zone = raw.get("zone", zone_of.get(name, "default"))
        built.append(
            make_entry(name, zone, raw.get("networks", [zone]), raw.get("depends_on", []))
        )
    return built
def _build_networks(zones: dict[str, list[str]]) -> list[dict[str, str]]:
    """Build one network object per zone, in zone order.

    The well-known zones get conventional subnets (dmz=10.0.1.0/24,
    internal=10.0.2.0/24, management=10.0.3.0/24).  Any other zone,
    including ``external``, gets only a ``name`` and no ``cidr`` key, leaving
    the subnet to the bridge default.
    """
    conventional = {
        "dmz": "10.0.1.0/24",
        "internal": "10.0.2.0/24",
        "management": "10.0.3.0/24",
    }
    return [
        {"name": zone, "cidr": conventional[zone]} if zone in conventional else {"name": zone}
        for zone in zones
    ]
def _build_zone_cidrs(zones: dict[str, list[str]]) -> dict[str, str]:
    """Map zone names to CIDR blocks for iptables rules.

    ``external`` and any zone without a conventional subnet map to
    ``0.0.0.0/0``.
    """
    # NOTE(review): because unknown zones fall back to 0.0.0.0/0 (any address),
    # a firewall rule naming a custom zone presumably matches every source —
    # confirm this fallback is intended.
    anywhere = "0.0.0.0/0"
    known = {
        "external": anywhere,
        "dmz": "10.0.1.0/24",
        "internal": "10.0.2.0/24",
        "management": "10.0.3.0/24",
    }
    mapping: dict[str, str] = {}
    for zone in zones:
        mapping[zone] = known.get(zone, anywhere)
    return mapping
def _find_db_user(users: list[dict[str, Any]]) -> str:
    """Return the username of the first non-admin user with ``db`` access.

    A user qualifies when ``"db"`` is in its ``hosts`` and ``"admins"`` is
    not in its ``groups``.  Falls back to ``"app_user"`` when no user
    qualifies or the match has no ``username``.
    """
    fallback = "app_user"
    match = next(
        (
            user
            for user in users
            if "db" in user.get("hosts", []) and "admins" not in user.get("groups", [])
        ),
        None,
    )
    return fallback if match is None else match.get("username", fallback)
def _find_db_pass(users: list[dict[str, Any]]) -> str:
    """Return the password of the first non-admin user with ``db`` access.

    Uses the same selection rule as the DB-user lookup: ``"db"`` in
    ``hosts`` and ``"admins"`` not in ``groups``.  Defaults to
    ``"AppUs3r!2024"`` when nothing matches or the match lacks a password.
    """
    default_password = "AppUs3r!2024"
    for candidate in users:
        if "db" not in candidate.get("hosts", []):
            continue
        if "admins" in candidate.get("groups", []):
            continue
        return candidate.get("password", default_password)
    return default_password
def _find_mysql_root_pass(users: list[dict[str, Any]]) -> str:
    """Return the MySQL root password taken from the ``admin`` user.

    Picks the first user whose ``username`` is ``"admin"`` and whose
    ``hosts`` include ``"db"``.  Group membership is not checked.
    Defaults to ``"r00tP@ss!"``.
    """
    default_password = "r00tP@ss!"
    admin = next(
        (
            user
            for user in users
            if user.get("username") == "admin" and "db" in user.get("hosts", [])
        ),
        None,
    )
    if admin is None:
        return default_password
    return admin.get("password", default_password)
def _find_smb_user(users: list[dict[str, Any]]) -> str:
    """Return the username of the first non-admin user on the ``files`` host.

    A user qualifies when ``"files"`` is in its ``hosts`` and ``"admins"``
    is not in its ``groups``.  Defaults to ``"smbuser"``.
    """
    default_name = "smbuser"
    for account in users:
        on_files_host = "files" in account.get("hosts", [])
        if on_files_host and "admins" not in account.get("groups", []):
            return account.get("username", default_name)
    return default_name
def _find_smb_pass(users: list[dict[str, Any]]) -> str:
    """Return the password of the first non-admin user on the ``files`` host.

    Uses the same selection rule as the SMB-user lookup.  Defaults to
    ``"smbP@ss!"`` when nothing matches or the match lacks a password.
    """
    default_password = "smbP@ss!"
    eligible = (
        account
        for account in users
        if "files" in account.get("hosts", []) and "admins" not in account.get("groups", [])
    )
    first = next(eligible, None)
    return default_password if first is None else first.get("password", default_password)
def _find_smb_shares(spec: SnapshotSpec) -> list[str]:
    """Extract Samba share names from the snapshot ``files`` dict.

    Only keys for the ``files`` container are considered.  The share name is
    the first path segment after the first ``/srv/shares/`` in the path.
    Returns the sorted, de-duplicated share names, or ``["general"]`` when
    none are found.
    """
    marker = "/srv/shares/"
    shares: set[str] = set()
    for key in spec.files:
        if not key.startswith("files:"):
            continue
        path = key.split(":", 1)[1]
        _, found, remainder = path.partition(marker)
        if not found:
            continue
        share = remainder.split("/", 1)[0]
        # Paths like "/srv/shares/" or "/srv/shares//x" have no share segment;
        # skipping them avoids emitting an empty share name.
        if share:
            shares.add(share)
    return sorted(shares) or ["general"]