"""Materialize a :class:`Graph` into a dict of ``{filename: source}``. Determinism guarantees: * One file per module, named ``.py``. * Within a file, functions emitted in :attr:`Node.decl_order`. * Imports sorted: stdlib first (alpha), then ``from import `` (alpha by module, alpha by name). * Pattern constants emitted only if used, in alpha order. * Out-edges of a node iterated in insertion order, which matters for ``sequential_calls`` and ``try_call_with_fallback`` semantics. The orchestrator is a pure function: same graph in, same source out. """ from __future__ import annotations from collections import defaultdict from typing import Iterable from graphforge.graph.schema import Edge, Graph, Node from graphforge.materializer import codegen, patterns HEADER = '"""Auto-generated by graphforge.materializer. Do not edit by hand."""\n' FUTURE = "from __future__ import annotations\n" # ---- helpers --------------------------------------------------------- def _out_edges_in_order(graph: Graph, qualified: str) -> list[Edge]: """Out-edges of ``qualified`` in insertion order.""" return [e for e in graph.edges if e.caller == qualified] def _nodes_by_module(graph: Graph) -> dict[str, list[Node]]: """Map module-name -> nodes in decl_order.""" by_mod: dict[str, list[Node]] = defaultdict(list) for n in graph.nodes: by_mod[n.module].append(n) for ns in by_mod.values(): ns.sort(key=lambda n: (n.decl_order, n.name)) return by_mod def _cross_module_imports(graph: Graph, module: str) -> list[tuple[str, str]]: """``[(callee_module, callee_name), ...]`` needed by ``module``.""" pairs: set[tuple[str, str]] = set() for e in graph.edges: caller_mod = e.caller.split(".", 1)[0] if caller_mod != module: continue callee_mod, callee_name = e.callee.split(".", 1) if callee_mod != module: pairs.add((callee_mod, callee_name)) return sorted(pairs) def _stdlib_imports_for(nodes: Iterable[Node]) -> list[str]: """Stdlib imports the templates in this module require.""" needed: set[str] = set() for n in nodes: needed |= codegen.template_imports(n.body_template) return sorted(needed) def _patterns_used_by(nodes: Iterable[Node]) -> list[str]: """Named patterns referenced by validate_with_regex nodes in this module.""" used: set[str] = set() for n in nodes: if n.body_template == "validate_with_regex": name = str(n.body_template_args.get("pattern", "")) if patterns.get_pattern(name) is not None: used.add(name) return sorted(used) # ---- core ------------------------------------------------------------ def materialize(graph: Graph) -> dict[str, str]: """Project ``graph`` to a ``{filename: source}`` map. Modules with zero nodes are still emitted as empty files (just header + future import) so that downstream import-resolution sees them. """ by_mod = _nodes_by_module(graph) files: dict[str, str] = {} for module in graph.modules: nodes = by_mod.get(module.name, []) files[f"{module.name}.py"] = _render_module(graph, module.name, nodes) return files def _render_module(graph: Graph, module_name: str, nodes: list[Node]) -> str: parts: list[str] = [HEADER, FUTURE, "\n"] # Stdlib imports. for imp in _stdlib_imports_for(nodes): parts.append(f"import {imp}\n") # Cross-module function imports. for callee_mod, callee_name in _cross_module_imports(graph, module_name): parts.append(f"from {callee_mod} import {callee_name}\n") if ( any(_stdlib_imports_for(nodes)) or _cross_module_imports(graph, module_name) ): parts.append("\n") # Pattern constants used in this module. We emit a plain string literal # (not a raw-string-prefixed one) because ``repr()`` already produces a # valid Python string literal — wrapping it in ``r"..."`` would double # the backslashes and break regex metacharacters like ``\s`` and ``\d``. used_patterns = _patterns_used_by(nodes) for name in used_patterns: regex = patterns.get_pattern(name) constant = patterns.constant_name(name) parts.append(f"{constant} = {regex!r}\n") if used_patterns: parts.append("\n") # Functions. for i, node in enumerate(nodes): out_edges = _out_edges_in_order(graph, node.qualified_name) body = codegen.render_body(node, out_edges, graph) parts.append(f"def {node.name}{node.signature}:\n{body}\n") if i != len(nodes) - 1: parts.append("\n") source = "".join(parts) # Ensure exactly one trailing newline. return source.rstrip("\n") + "\n"