File size: 10,745 Bytes
7952f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""Canonical graph schema.

The graph is the single source of truth for an in-progress program. Every
materialization is a deterministic function of (graph, template library).

Wire format mirrors the JSON shape documented in PROPOSAL.md §3.1, exactly:

    {
      "modules": [{"name": ..., "responsibility": ...}, ...],
      "nodes":   [{"name": ..., "module": ..., "signature": ...,
                   "body_template": ..., "body_template_args": {...},
                   "purity": ..., "error_policy": ..., "decl_order": ...}, ...],
      "edges":   [{"caller": "<module>.<name>",
                   "callee": "<module>.<name>",
                   "arg_mapping": [{"caller_arg": ..., "callee_param": ...}, ...]}, ...]
    }

This module enforces shape and well-formedness only. Higher-order invariants
(unique names, edge endpoints exist, no cycles, type-flow compatibility) are
enforced by the action dispatcher and the type engine, not the schema, so
that callers can build partial / invalid graphs and inspect why they fail.
"""

from __future__ import annotations

import hashlib
import json
import keyword
from typing import Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator

# ----------------------------------------------------------------------
# Enumerated tags
# ----------------------------------------------------------------------

# Responsibility tags constrain which kinds of nodes a module is allowed to
# host. This is the canonical set: add new tags deliberately, because tasks
# encode constraints against this vocabulary.
ResponsibilityTag = Literal[
    "io",
    "validation",
    "transform",
    "orchestration",
    "storage",
    "formatting",
    "lookup",
    "policy",
    "logging",
    "computation",
]

# Purity tag carried by each node; ``Node.purity`` defaults to "impure".
Purity = Literal["pure", "impure"]

# How a function handles errors in its body:
#   "guard"     - the body includes a guard / try-except.
#   "propagate" - the body deliberately lets errors flow up to the caller.
#   "none"      - the default; no claim either way.
ErrorPolicy = Literal["guard", "propagate", "none"]


# ----------------------------------------------------------------------
# Atomic records
# ----------------------------------------------------------------------


class Module(BaseModel):
    """A declared module — one Python file at materialization time.

    Attributes:
        name: Module name. Must be a non-keyword Python identifier that does
            not start with an underscore.
        responsibility: One of the ``ResponsibilityTag`` values.
    """

    model_config = ConfigDict(extra="forbid", frozen=False)

    name: str = Field(..., min_length=1)
    responsibility: ResponsibilityTag

    @field_validator("name")
    @classmethod
    def _name_is_identifier(cls, v: str) -> str:
        """Validate that ``name`` can be used as a Python module name.

        Raises:
            ValueError: if ``v`` is not an identifier, is a reserved keyword,
                or starts with an underscore.
        """
        if not v.isidentifier():
            raise ValueError(f"module name {v!r} is not a Python identifier")
        # str.isidentifier() returns True for reserved words such as "class"
        # or "import"; a module with such a name could never be imported.
        if keyword.iskeyword(v):
            raise ValueError(f"module name {v!r} is a reserved Python keyword")
        if v.startswith("_"):
            raise ValueError(f"module name {v!r} must not start with an underscore")
        return v


class Node(BaseModel):
    """A declared function. ``body_template`` may be unset until attach_body.

    Attributes:
        name: Function name; a non-keyword Python identifier.
        module: Name of the hosting module; a Python identifier, so that
            ``qualified_name`` always has the ``<module>.<name>`` shape that
            edge endpoints require.
        signature: Parameter list and return arrow, e.g. ``"(x: int) -> bool"``.
        body_template: Template identifier, or ``None`` when no body is attached.
        body_template_args: Arguments for the body template.
        purity: ``"pure"`` or ``"impure"`` (default ``"impure"``).
        error_policy: ``"guard"``, ``"propagate"`` or ``"none"`` (default).
        decl_order: Integer ordering key (default ``0``).
    """

    model_config = ConfigDict(extra="forbid", frozen=False)

    name: str = Field(..., min_length=1)
    module: str = Field(..., min_length=1)
    signature: str = Field(..., min_length=2)  # e.g., "(x: int) -> bool"
    body_template: Optional[str] = None
    body_template_args: dict[str, object] = Field(default_factory=dict)
    purity: Purity = "impure"
    error_policy: ErrorPolicy = "none"
    decl_order: int = 0

    @field_validator("name")
    @classmethod
    def _name_is_identifier(cls, v: str) -> str:
        """Validate that ``name`` can be used as a Python function name.

        Raises:
            ValueError: if ``v`` is not an identifier or is a reserved keyword.
        """
        if not v.isidentifier():
            raise ValueError(f"node name {v!r} is not a Python identifier")
        # str.isidentifier() returns True for reserved words ("def", "class",
        # ...), which cannot be used as a function name.
        if keyword.iskeyword(v):
            raise ValueError(f"node name {v!r} is a reserved Python keyword")
        return v

    @field_validator("module")
    @classmethod
    def _module_is_identifier(cls, v: str) -> str:
        """Validate that ``module`` is identifier-shaped.

        Whether the module is actually declared in the graph is not checked
        here. This check only guarantees that ``qualified_name`` contains
        exactly one dot and can therefore be used as an edge endpoint.

        Raises:
            ValueError: if ``v`` is not a Python identifier.
        """
        if not v.isidentifier():
            raise ValueError(f"node module {v!r} is not a Python identifier")
        return v

    @field_validator("signature")
    @classmethod
    def _signature_shape(cls, v: str) -> str:
        """Run a cheap surface check on the signature string.

        Only the opening parenthesis and the presence of a return arrow are
        checked; the type engine does the real parse.

        Raises:
            ValueError: if the signature does not start with ``(`` (leading
                whitespace is ignored) or contains no ``->``.
        """
        if not v.lstrip().startswith("("):
            raise ValueError(f"signature must start with '(': got {v!r}")
        if "->" not in v:
            raise ValueError(f"signature must include '->' return arrow: got {v!r}")
        return v

    # Convenience -----------------------------------------------------

    @property
    def qualified_name(self) -> str:
        """``<module>.<name>`` — the canonical address used on edges."""
        return f"{self.module}.{self.name}"


class ArgMapping(BaseModel):
    """One wiring entry on an edge: caller argument -> callee parameter.

    Both fields are required, non-empty strings.
    """

    model_config = ConfigDict(frozen=False, extra="forbid")

    # Argument on the calling side.
    caller_arg: str = Field(min_length=1)
    # Parameter on the called side that receives it.
    callee_param: str = Field(min_length=1)


class Edge(BaseModel):
    """A CALLS relation between two nodes.

    ``caller`` and ``callee`` are qualified node names of the form
    ``<module>.<name>``; ``arg_mapping`` lists how arguments are wired and
    defaults to an empty list.
    """

    model_config = ConfigDict(frozen=False, extra="forbid")

    caller: str = Field(min_length=3)
    callee: str = Field(min_length=3)
    arg_mapping: list[ArgMapping] = Field(default_factory=list)

    @field_validator("caller", "callee")
    @classmethod
    def _qualified(cls, v: str) -> str:
        """Require exactly one dot separating two identifier parts."""
        module_part, dot, name_part = v.partition(".")
        # No dot at all, or a second dot left over in the tail, means the
        # endpoint is not of the form '<module>.<name>'.
        if dot == "" or "." in name_part:
            raise ValueError(
                f"edge endpoint {v!r} is not qualified (expected '<module>.<name>')"
            )
        both_identifiers = module_part.isidentifier() and name_part.isidentifier()
        if not both_identifiers:
            raise ValueError(f"edge endpoint {v!r} has non-identifier parts")
        return v


# ----------------------------------------------------------------------
# Graph
# ----------------------------------------------------------------------


class Graph(BaseModel):
    """Canonical graph state: declared modules, nodes, and CALLS edges.

    Instances are mutable; ``snapshot`` returns an independent deep copy.
    Lookup helpers are linear scans that return the first match, and the
    structural derivations tolerate edges whose endpoints are not declared
    as nodes or modules.
    """

    model_config = ConfigDict(extra="forbid", frozen=False)

    modules: list[Module] = Field(default_factory=list)
    nodes: list[Node] = Field(default_factory=list)
    edges: list[Edge] = Field(default_factory=list)

    # ----- lookup ----------------------------------------------------

    def find_module(self, name: str) -> Optional[Module]:
        """Return the first module called ``name``, or ``None``."""
        for m in self.modules:
            if m.name == name:
                return m
        return None

    def find_node(self, name: str, module: str) -> Optional[Node]:
        """Return the first node matching ``name`` and ``module``, or ``None``."""
        for n in self.nodes:
            if n.name == name and n.module == module:
                return n
        return None

    def find_node_qualified(self, qualified: str) -> Optional[Node]:
        """Look up a node by its ``<module>.<name>`` address.

        Returns ``None`` when ``qualified`` does not contain exactly one dot
        or when no such node exists.
        """
        if qualified.count(".") != 1:
            return None
        mod, nm = qualified.split(".")
        return self.find_node(nm, mod)

    def find_edge(self, caller: str, callee: str) -> Optional[Edge]:
        """Return the first edge from ``caller`` to ``callee``, or ``None``."""
        for e in self.edges:
            if e.caller == caller and e.callee == callee:
                return e
        return None

    def nodes_in_module(self, module: str) -> list[Node]:
        """Return all nodes whose ``module`` equals ``module``, in list order."""
        return [n for n in self.nodes if n.module == module]

    def callers_of(self, qualified: str) -> list[str]:
        """Return the caller endpoint of every edge whose callee is ``qualified``."""
        return [e.caller for e in self.edges if e.callee == qualified]

    def callees_of(self, qualified: str) -> list[str]:
        """Return the callee endpoint of every edge whose caller is ``qualified``."""
        return [e.callee for e in self.edges if e.caller == qualified]

    def fan_in(self, qualified: str) -> int:
        """Number of edges pointing at ``qualified`` (duplicates counted)."""
        return len(self.callers_of(qualified))

    def fan_out(self, qualified: str) -> int:
        """Number of edges leaving ``qualified`` (duplicates counted)."""
        return len(self.callees_of(qualified))

    # ----- structural derivations ------------------------------------

    def import_edges(self) -> set[tuple[str, str]]:
        """Set of (caller_module, callee_module) pairs from cross-module edges."""
        out: set[tuple[str, str]] = set()
        for e in self.edges:
            # Endpoints are '<module>.<name>'; the module is the part before
            # the dot.
            cm = e.caller.split(".")[0]
            tm = e.callee.split(".")[0]
            if cm != tm:
                out.add((cm, tm))
        return out

    def has_module_cycle(self) -> bool:
        """True iff the cross-module import graph contains a directed cycle."""
        adj: dict[str, set[str]] = {m.name: set() for m in self.modules}
        for src, dst in self.import_edges():
            # setdefault: edges may mention modules that were never declared.
            adj.setdefault(src, set()).add(dst)
            adj.setdefault(dst, set())
        # Three-colour DFS: GRAY marks vertices on the current DFS path, so
        # reaching a GRAY vertex means a back edge, i.e. a cycle.
        WHITE, GRAY, BLACK = 0, 1, 2
        color: dict[str, int] = {k: WHITE for k in adj}

        def visit(u: str) -> bool:
            color[u] = GRAY
            for v in adj.get(u, ()):
                if color[v] == GRAY:
                    return True
                if color[v] == WHITE and visit(v):
                    return True
            color[u] = BLACK
            return False

        return any(color[u] == WHITE and visit(u) for u in adj)

    def call_graph_depth(self) -> int:
        """Longest path length (in edges) in the function call DAG.

        If the call graph is cyclic, returns the special value -1 (callers
        should treat this as an invariant violation). A graph with no nodes
        and no edges has depth 0.
        """
        adj: dict[str, list[str]] = {n.qualified_name: [] for n in self.nodes}
        for e in self.edges:
            # setdefault: edges may reference endpoints that are not declared
            # as nodes; they still take part in the depth computation.
            adj.setdefault(e.caller, []).append(e.callee)
            adj.setdefault(e.callee, [])
        memo: dict[str, int] = {}
        ON_STACK = -2  # marker for vertices on the current DFS path

        def dfs(u: str) -> int:
            if u in memo:
                if memo[u] == ON_STACK:
                    return -1  # back edge: the call graph has a cycle
                return memo[u]
            memo[u] = ON_STACK
            best = 0
            for v in adj.get(u, ()):
                d = dfs(v)
                if d == -1:
                    return -1
                best = max(best, d + 1)
            memo[u] = best
            return best

        deepest = 0
        for u in adj:
            d = dfs(u)
            if d == -1:
                # One cycle decides the answer; no need to explore further.
                return -1
            deepest = max(deepest, d)
        return deepest

    # ----- copying / hashing -----------------------------------------

    def snapshot(self) -> "Graph":
        """Deep copy. Used by the dispatcher for atomic action rollback."""
        return self.model_copy(deep=True)

    def structural_hash(self) -> str:
        """Stable SHA-256 over a canonical JSON projection.

        Insensitive to the ordering of the ``modules``, ``nodes`` and
        ``edges`` lists, but sensitive to every field value, including
        ``decl_order``, because that affects materialized output. The order
        of entries inside an edge's ``arg_mapping`` is preserved and thus
        also affects the hash.

        Records are ordered by their name key first and by their full
        canonical JSON second, so graphs that contain duplicate names (which
        the schema permits) still hash independently of list order.

        Values that are not JSON-serializable are rendered with ``str``, so
        the hash is only as stable as their ``str`` representation.
        """

        def canonical(record: dict[str, object]) -> str:
            return json.dumps(record, sort_keys=True, default=str)

        canon: dict[str, object] = {
            "modules": sorted(
                (m.model_dump() for m in self.modules),
                key=lambda d: (d["name"], canonical(d)),
            ),
            "nodes": sorted(
                (n.model_dump() for n in self.nodes),
                key=lambda d: (d["module"], d["name"], canonical(d)),
            ),
            "edges": sorted(
                (e.model_dump() for e in self.edges),
                key=lambda d: (d["caller"], d["callee"], canonical(d)),
            ),
        }
        blob = canonical(canon).encode("utf-8")
        return hashlib.sha256(blob).hexdigest()

    # ----- factories -------------------------------------------------

    @classmethod
    def empty(cls) -> "Graph":
        """Return a graph with no modules, nodes, or edges."""
        return cls()