File size: 14,065 Bytes
13ed231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
"""
ExecutionEncoder: Graph-Based Transformer for Execution Plan Encoding

This module implements the transformer-based encoder that maps
(plan_graph, provenance_metadata) → z_e ∈ R^1024.

The ExecutionEncoder is the complementary half of the JEPA dual-encoder architecture,
encoding proposed execution plans into the same latent space as governance policies
to enable energy-based security validation.

Architecture:
- Graph Neural Network for tool-call dependency encoding
- Provenance-aware attention mechanism
- Scope metadata integration
- Differentiable for end-to-end training with energy functions

References:
- Graph Attention Networks: https://arxiv.org/abs/1710.10903
- Relational Graph Convolutional Networks: https://arxiv.org/abs/1703.06103
"""

import hashlib
from enum import IntEnum
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from pydantic import BaseModel, Field, field_validator


class TrustTier(IntEnum):
    """
    Trust levels for data provenance (as per Dawn Song's workstream).

    Smaller values denote more trusted sources. Members start at 1, so 0 is
    not a valid tier; the encoder's tier embedding keeps index 0 for padding.
    """

    # System instructions, internal databases.
    INTERNAL = 1
    # Verified external sources.
    SIGNED_PARTNER = 2
    # Untrusted retrieval (RAG, web scraping).
    PUBLIC_WEB = 3


class ToolCallNode(BaseModel):
    """
    A single tool invocation in the execution plan graph.

    Field groups:
    - call:       ``tool_name`` (non-empty) and its ``arguments`` dict
    - provenance: ``provenance_tier`` (defaults to INTERNAL) and an optional
                  ``provenance_hash`` of the source
    - scope:      ``scope_volume`` (>= 1) and ``scope_sensitivity`` (1..5)
    - graph:      ``node_id``, the identifier edges in ExecutionPlan refer to
    """

    # --- the call itself ---
    tool_name: str = Field(..., min_length=1, description="Name of the tool being invoked")
    arguments: dict[str, Any] = Field(default_factory=dict, description="Tool arguments")

    # --- provenance ---
    provenance_tier: TrustTier = Field(default=TrustTier.INTERNAL, description="Trust tier of instruction source")
    provenance_hash: str | None = Field(default=None, description="Cryptographic hash of source")

    # --- scope ---
    scope_volume: int = Field(default=1, ge=1, description="Data volume (rows, records, files)")
    scope_sensitivity: int = Field(default=1, ge=1, le=5, description="Sensitivity level (1=public, 5=critical)")

    # --- graph ---
    node_id: str = Field(..., description="Unique node identifier")

    @field_validator('provenance_tier', mode='before')
    @classmethod
    def parse_trust_tier(cls, v):
        """
        Convert a plain int to ``TrustTier`` before field validation.

        Non-int values are returned untouched for pydantic to validate. An int
        that is not a tier value makes ``TrustTier(v)`` raise ValueError.
        """
        return TrustTier(v) if isinstance(v, int) else v


class ExecutionPlan(BaseModel):
    """
    Complete execution plan represented as a typed tool-call graph.

    Invariants enforced at validation time:
    - at least one node;
    - every ``node_id`` is unique, so each edge endpoint names exactly one node;
    - every edge endpoint references an existing node id.
    """
    nodes: list[ToolCallNode] = Field(..., min_length=1, description="Tool invocation nodes")
    edges: list[tuple[str, str]] = Field(default_factory=list, description="Data flow edges (src_id, dst_id)")

    @field_validator('nodes')
    @classmethod
    def validate_unique_node_ids(cls, v):
        """
        Reject plans whose nodes share a ``node_id``.

        Duplicate ids make edges ambiguous: an id-to-index lookup can only keep
        one of the nodes, so edges would silently attach to the wrong node.

        Raises:
            ValueError: If any ``node_id`` occurs more than once.
        """
        seen: set[str] = set()
        duplicates: set[str] = set()
        for node in v:
            if node.node_id in seen:
                duplicates.add(node.node_id)
            seen.add(node.node_id)
        if duplicates:
            raise ValueError(f"Duplicate node_id(s): {sorted(duplicates)}")
        return v

    @field_validator('edges')
    @classmethod
    def validate_edges(cls, v, info):
        """
        Ensure edge endpoints reference valid nodes.

        If ``nodes`` failed validation it is absent from ``info.data``. The
        check is then skipped, and the error is reported on ``nodes`` instead.

        Raises:
            ValueError: If an edge endpoint is not a known node id.
        """
        if 'nodes' not in info.data:
            return v

        node_ids = {node.node_id for node in info.data['nodes']}
        for src, dst in v:
            if src not in node_ids or dst not in node_ids:
                raise ValueError(f"Edge ({src}, {dst}) references non-existent node")
        return v


class ProvenanceEmbedding(nn.Module):
    """
    Per-node embedding of provenance tier and scope metadata.

    A learned embedding of the trust-tier index is concatenated with a linear
    projection of ``[log1p(volume), sensitivity]``. The pair is then fused back
    to ``hidden_dim`` by another linear layer. No cryptographic hash is
    consumed here; only the tier index carries provenance information.
    """

    def __init__(self, hidden_dim: int, num_tiers: int = 3):
        super().__init__()
        self.hidden_dim = hidden_dim
        # num_tiers + 1 rows: index 0 is kept free for padding, 1..num_tiers for tiers.
        self.tier_embedding = nn.Embedding(num_tiers + 1, hidden_dim)
        # Two scalar scope features: log-scaled volume and raw sensitivity.
        self.scope_projection = nn.Linear(2, hidden_dim)
        # Input is [tier embedding ‖ scope embedding].
        self.fusion = nn.Linear(2 * hidden_dim, hidden_dim)

    def forward(
        self,
        tier_indices: torch.Tensor,
        scope_volume: torch.Tensor,
        scope_sensitivity: torch.Tensor
    ) -> torch.Tensor:
        """
        Embed tier and scope metadata.

        Args:
            tier_indices: Integer tier ids of some shape S.
            scope_volume: Data volumes, same shape S.
            scope_sensitivity: Sensitivity levels, same shape S.

        Returns:
            Tensor of shape ``S + (hidden_dim,)``.
        """
        # log1p compresses volumes spanning many orders of magnitude (1 to 1M+).
        volume_feature = torch.log1p(scope_volume.float())
        scope_features = torch.stack((volume_feature, scope_sensitivity.float()), dim=-1)

        fused_input = torch.cat(
            (self.tier_embedding(tier_indices), self.scope_projection(scope_features)),
            dim=-1,
        )
        return self.fusion(fused_input)


class GraphAttention(nn.Module):
    """
    Graph Attention layer for encoding tool-call dependencies.

    Multi-head scaled dot-product attention in which query node ``i`` may
    attend to key node ``j`` only when ``adjacency[..., i, j]`` is non-zero.
    Every node always attends to itself. The added self-loop guarantees each
    softmax row has at least one finite score, so no row becomes NaN.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        """
        Args:
            hidden_dim: Node feature width; split evenly across heads.
            num_heads: Number of attention heads.
            dropout: Dropout rate applied to the attention weights.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not divide
                ``hidden_dim``.
        """
        super().__init__()
        # Explicit raise (not `assert`) so the check survives `python -O`, and it
        # runs before the division below can fail with a less helpful error.
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        x: torch.Tensor,
        adjacency: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply graph attention.

        Args:
            x: Node features [batch_size, num_nodes, hidden_dim]
            adjacency: Adjacency matrix [batch_size, num_nodes, num_nodes] or
                       [num_nodes, num_nodes] (shared across the batch).
                       Any non-zero entry (float, int or bool) is an edge.

        Returns:
            Updated node features [batch_size, num_nodes, hidden_dim]
        """
        batch_size, num_nodes, _ = x.shape

        # Accept an unbatched adjacency; without this, unsqueeze(1) below would
        # produce [N, 1, N] and broadcast against the wrong axes.
        if adjacency.dim() == 2:
            adjacency = adjacency.unsqueeze(0)

        q = self.q_proj(x).view(batch_size, num_nodes, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, num_nodes, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, num_nodes, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Allowed (query, key) pairs: graph edges plus self-loops.
        # Shape [batch, 1, nodes, nodes] broadcasts over the head dimension.
        edges = adjacency.unsqueeze(1) != 0
        self_loops = torch.eye(num_nodes, dtype=torch.bool, device=x.device)
        allowed = edges | self_loops

        scores = scores.masked_fill(~allowed, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, num_nodes, self.hidden_dim)

        return self.out_proj(out)


class GraphTransformerBlock(nn.Module):
    """
    Pre-norm transformer block with graph-masked self-attention.

    Two residual sub-layers, each fed a LayerNorm-ed copy of its input:
    1. ``GraphAttention`` restricted by ``adjacency``;
    2. a position-wise feed-forward network (4x expansion, GELU, dropout).
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()
        self.attention = GraphAttention(hidden_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)

        inner_dim = hidden_dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, hidden_dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        adjacency: torch.Tensor
    ) -> torch.Tensor:
        """Run attention then feed-forward, each as a pre-norm residual update."""
        after_attention = x + self.attention(self.norm1(x), adjacency)
        feed_forward = self.ffn(self.norm2(after_attention))
        return after_attention + feed_forward


class ExecutionEncoder(nn.Module):
    """
    Graph-based transformer encoder mapping execution plans to z_e ∈ R^1024.

    Encodes:
    - Tool invocation sequences
    - Data flow dependencies (graph edges)
    - Provenance metadata (trust tiers)
    - Scope metadata (volume + sensitivity)

    Plans longer than ``max_nodes`` are truncated to their first ``max_nodes``
    nodes, and edges touching a dropped node are ignored. Each plan is encoded
    at its true length (no padding rows), so z_e does not depend on how much
    padding ``max_nodes`` would otherwise require.

    Performance targets:
        - Latency: <100ms on CPU (pairs with GovernanceEncoder's 98ms)
        - Memory: <500MB
        - Differentiable: Yes
    """

    def __init__(
        self,
        latent_dim: int = 1024,
        hidden_dim: int = 512,
        num_layers: int = 4,
        num_heads: int = 8,
        max_nodes: int = 64,
        dropout: float = 0.1,
        vocab_size: int = 10000
    ):
        """
        Args:
            latent_dim: Size of the output latent vector z_e.
            hidden_dim: Node feature width inside the graph transformer.
            num_layers: Number of GraphTransformerBlock layers.
            num_heads: Attention heads per layer (must divide hidden_dim).
            max_nodes: Maximum number of plan nodes encoded; also the size of
                the position-embedding table.
            dropout: Dropout rate for attention, FFN and projection head.
            vocab_size: Number of hashed token ids. Id 0 is never produced by
                ``_tokenize`` and stays reserved for padding.

        Raises:
            ValueError: If ``vocab_size`` < 2.
        """
        super().__init__()

        if vocab_size < 2:
            raise ValueError("vocab_size must be >= 2 (token id 0 is reserved for padding)")

        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.max_nodes = max_nodes
        self.vocab_size = vocab_size

        # Token embeddings for tool names and arguments
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(max_nodes, hidden_dim)

        # Provenance and scope embeddings
        self.provenance_embedding = ProvenanceEmbedding(hidden_dim)

        # Graph transformer layers
        self.layers = nn.ModuleList([
            GraphTransformerBlock(hidden_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

        # Pooling and projection
        self.attention_pool = nn.Linear(hidden_dim, 1)
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, latent_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(latent_dim * 2, latent_dim),
            nn.LayerNorm(latent_dim)
        )

        self.input_norm = nn.LayerNorm(hidden_dim)

    def _tokenize(self, text: str) -> int:
        """
        Hash-based tokenization (v0.1.0) into ``[1, vocab_size)``.

        The built-in ``hash()`` must not be used here: it is salted per
        process for ``str`` (PYTHONHASHSEED), so the same tool call would map
        to a different embedding row on every run and trained weights would
        not transfer. BLAKE2b is deterministic across processes and machines.
        Id 0 is skipped so it never collides with padding.

        Future: Replace with BPE tokenizer (v0.2.0) to reduce collisions.
        """
        vocab_size = self.token_embedding.num_embeddings
        # surrogatepass: never fail on strings containing lone surrogates.
        digest = hashlib.blake2b(text.encode("utf-8", "surrogatepass"), digest_size=8).digest()
        return 1 + int.from_bytes(digest, "big") % (vocab_size - 1)

    def _create_adjacency_matrix(
        self,
        num_nodes: int,
        edges: list[tuple[int, int]],
        device: torch.device
    ) -> torch.Tensor:
        """
        Build a dense [num_nodes, num_nodes] 0/1 adjacency matrix from an edge list.

        Edges with an endpoint outside ``[0, num_nodes)`` are dropped. Negative
        indices would otherwise wrap around to the end of the matrix.
        """
        adjacency = torch.zeros(num_nodes, num_nodes, device=device)
        valid = [
            (src, dst) for src, dst in edges
            if 0 <= src < num_nodes and 0 <= dst < num_nodes
        ]
        if valid:
            rows, cols = zip(*valid)
            adjacency[list(rows), list(cols)] = 1
        return adjacency

    def forward(
        self,
        plan: ExecutionPlan | dict[str, Any]
    ) -> torch.Tensor:
        """
        Encode execution plan into latent vector.

        Args:
            plan: ExecutionPlan or dict conforming to ExecutionPlan schema

        Returns:
            z_e: Latent vector [1, latent_dim]
        """
        # Validate and parse input
        if not isinstance(plan, ExecutionPlan):
            plan = ExecutionPlan(**plan)

        # Truncate to the capacity of the position-embedding table.
        nodes = plan.nodes[:self.max_nodes]
        num_nodes = len(nodes)

        # Map node ids to row indices; keep only edges whose endpoints both
        # survived truncation.
        node_id_to_idx = {node.node_id: i for i, node in enumerate(nodes)}
        edge_indices = [
            (node_id_to_idx[src], node_id_to_idx[dst])
            for src, dst in plan.edges
            if src in node_id_to_idx and dst in node_id_to_idx
        ]

        # One token per node: tool name + arguments (sorted by key) for a
        # richer representation than the name alone.
        tool_tokens = []
        for node in nodes:
            arg_str = ",".join(f"{k}={v}" for k, v in sorted(node.arguments.items()))
            tool_tokens.append(self._tokenize(f"{node.tool_name}({arg_str})"))

        # Infer device from model parameters so tensors land on the right device (cpu/mps/cuda)
        device = next(self.parameters()).device

        # No padding: every row below is a real node, so neither attention
        # pooling nor the projection sees filler positions.
        token_ids = torch.tensor(tool_tokens, dtype=torch.long, device=device).unsqueeze(0)
        position_ids = torch.arange(num_nodes, device=device).unsqueeze(0)

        tier_indices = torch.tensor(
            [int(node.provenance_tier) for node in nodes], dtype=torch.long, device=device
        ).unsqueeze(0)
        # Built as float so very large volumes cannot overflow int64; the
        # provenance embedding converts to float and log-scales it anyway.
        scope_volume = torch.tensor(
            [float(node.scope_volume) for node in nodes], dtype=torch.float32, device=device
        ).unsqueeze(0)
        scope_sensitivity = torch.tensor(
            [node.scope_sensitivity for node in nodes], dtype=torch.long, device=device
        ).unsqueeze(0)

        adjacency = self._create_adjacency_matrix(num_nodes, edge_indices, device).unsqueeze(0)

        # Embed and combine
        token_emb = self.token_embedding(token_ids)
        pos_emb = self.position_embedding(position_ids)
        prov_emb = self.provenance_embedding(tier_indices, scope_volume, scope_sensitivity)

        x = self.input_norm(token_emb + pos_emb + prov_emb)

        # Apply graph transformer layers
        for layer in self.layers:
            x = layer(x, adjacency)

        # Attention pooling over the plan's (real) nodes
        attn_scores = self.attention_pool(x).squeeze(-1)
        attn_weights = F.softmax(attn_scores, dim=-1).unsqueeze(1)
        pooled = torch.matmul(attn_weights, x).squeeze(1)

        # Project to latent space
        return self.projection(pooled)

    def encode_batch(self, plans: list[ExecutionPlan]) -> torch.Tensor:
        """
        Batch encoding of multiple execution plans.

        Args:
            plans: List of ExecutionPlan objects

        Returns:
            z_e: Latent vectors [batch_size, latent_dim]; an empty list yields
                 a [0, latent_dim] tensor instead of failing in ``torch.cat``.
        """
        if not plans:
            ref = next(self.parameters())
            return torch.empty(0, self.latent_dim, device=ref.device, dtype=ref.dtype)
        return torch.cat([self.forward(plan) for plan in plans], dim=0)


def create_execution_encoder(
    latent_dim: int = 1024,
    checkpoint_path: str | None = None,
    device: str = "cpu",
    **encoder_kwargs: Any,
) -> ExecutionEncoder:
    """
    Factory function to create ExecutionEncoder.

    Args:
        latent_dim: Dimension of output latent vector (must match GovernanceEncoder)
        checkpoint_path: Optional path to pretrained weights
        device: Device to load model on
        **encoder_kwargs: Extra ExecutionEncoder constructor arguments
            (e.g. ``hidden_dim``, ``num_layers``). These must match the
            checkpoint's architecture when ``checkpoint_path`` is given.

    Returns:
        Initialized ExecutionEncoder in inference (eval) mode
    """
    model = ExecutionEncoder(latent_dim=latent_dim, **encoder_kwargs)

    if checkpoint_path is not None:
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)

    model = model.to(device)
    # eval() recursively switches every submodule (all Dropout layers) to
    # inference mode. Assigning `model.training = False` would only flip the
    # top-level flag and leave dropout active, making outputs stochastic.
    model.eval()

    return model