"""
ExecutionEncoder: Graph-Based Transformer for Execution Plan Encoding
This module implements the transformer-based encoder that maps
(plan_graph, provenance_metadata) → z_e ∈ R^1024.
The ExecutionEncoder is the complementary half of the JEPA dual-encoder architecture,
encoding proposed execution plans into the same latent space as governance policies
to enable energy-based security validation.
Architecture:
- Graph Neural Network for tool-call dependency encoding
- Provenance-aware attention mechanism
- Scope metadata integration
- Differentiable for end-to-end training with energy functions
References:
- Graph Attention Networks: https://arxiv.org/abs/1710.10903
- Relational Graph Convolutional Networks: https://arxiv.org/abs/1703.06103
"""
import zlib
from enum import IntEnum
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from pydantic import BaseModel, Field, field_validator
class TrustTier(IntEnum):
"""Trust levels for data provenance (as per Dawn Song's workstream)."""
INTERNAL = 1 # System instructions, internal databases
SIGNED_PARTNER = 2 # Verified external sources
PUBLIC_WEB = 3 # Untrusted retrieval (RAG, web scraping)
class ToolCallNode(BaseModel):
"""A single tool invocation in the execution plan graph."""
tool_name: str = Field(..., min_length=1, description="Name of the tool being invoked")
arguments: dict[str, Any] = Field(default_factory=dict, description="Tool arguments")
# Provenance metadata
provenance_tier: TrustTier = Field(default=TrustTier.INTERNAL, description="Trust tier of instruction source")
provenance_hash: str | None = Field(default=None, description="Cryptographic hash of source")
# Scope metadata
scope_volume: int = Field(default=1, ge=1, description="Data volume (rows, records, files)")
scope_sensitivity: int = Field(default=1, ge=1, le=5, description="Sensitivity level (1=public, 5=critical)")
# Graph metadata
node_id: str = Field(..., description="Unique node identifier")
@field_validator('provenance_tier', mode='before')
@classmethod
def parse_trust_tier(cls, v):
"""Parse trust tier from int or TrustTier."""
if isinstance(v, int):
return TrustTier(v)
return v
class ExecutionPlan(BaseModel):
"""Complete execution plan represented as a typed tool-call graph."""
nodes: list[ToolCallNode] = Field(..., min_length=1, description="Tool invocation nodes")
edges: list[tuple[str, str]] = Field(default_factory=list, description="Data flow edges (src_id, dst_id)")
@field_validator('edges')
@classmethod
def validate_edges(cls, v, info):
"""Ensure edge endpoints reference valid nodes."""
if 'nodes' not in info.data:
return v
node_ids = {node.node_id for node in info.data['nodes']}
for src, dst in v:
if src not in node_ids or dst not in node_ids:
raise ValueError(f"Edge ({src}, {dst}) references non-existent node")
return v
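# Illustrative example (hypothetical tool name and node IDs): a single-node
# plan is the minimal valid ExecutionPlan; edges must reference declared
# node_ids or validation raises a ValueError.
#
#   plan = ExecutionPlan(
#       nodes=[ToolCallNode(node_id="n1", tool_name="db.query",
#                           provenance_tier=TrustTier.PUBLIC_WEB)],
#       edges=[],
#   )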
class ProvenanceEmbedding(nn.Module):
"""Embeds provenance metadata (trust tier + cryptographic hash)."""
def __init__(self, hidden_dim: int, num_tiers: int = 3):
super().__init__()
self.hidden_dim = hidden_dim
self.tier_embedding = nn.Embedding(num_tiers + 1, hidden_dim) # +1 for padding
self.scope_projection = nn.Linear(2, hidden_dim) # volume + sensitivity
self.fusion = nn.Linear(hidden_dim * 2, hidden_dim)
def forward(
self,
tier_indices: torch.Tensor,
scope_volume: torch.Tensor,
scope_sensitivity: torch.Tensor
) -> torch.Tensor:
"""Combine provenance tier and scope metadata."""
tier_emb = self.tier_embedding(tier_indices)
# Log-scale volume to handle wide range (1 to 1M+)
log_volume = torch.log1p(scope_volume.float()).unsqueeze(-1)
sensitivity = scope_sensitivity.float().unsqueeze(-1)
scope_features = torch.cat([log_volume, sensitivity], dim=-1)
scope_emb = self.scope_projection(scope_features)
combined = torch.cat([tier_emb, scope_emb], dim=-1)
return self.fusion(combined)
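# Shape sketch for ProvenanceEmbedding (assuming hidden_dim=512): for inputs
# of shape [batch, num_nodes], tier_emb and scope_emb are each
# [batch, num_nodes, 512]; their concatenation is [batch, num_nodes, 1024],
# and the fused output returned by forward() is [batch, num_nodes, 512].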
class GraphAttention(nn.Module):
"""
Graph Attention layer for encoding tool-call dependencies.
Implements message passing with edge-aware attention.
"""
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
dropout: float = 0.1
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.out_proj = nn.Linear(hidden_dim, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.scale = self.head_dim ** -0.5
def forward(
self,
x: torch.Tensor,
adjacency: torch.Tensor
) -> torch.Tensor:
"""
Apply graph attention.
Args:
x: Node features [batch_size, num_nodes, hidden_dim]
adjacency: Adjacency matrix [batch_size, num_nodes, num_nodes]
1 = edge exists, 0 = no edge
"""
batch_size, num_nodes, _ = x.shape
q = self.q_proj(x).view(batch_size, num_nodes, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(batch_size, num_nodes, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch_size, num_nodes, self.num_heads, self.head_dim).transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Mask attention to respect the graph structure; self-loops (the
        # diagonal) are added so every node attends at least to itself, which
        # also guarantees no softmax row is entirely -inf (and hence NaN)
        # for isolated or padded nodes.
        mask = adjacency.unsqueeze(1)  # [batch, 1, nodes, nodes]
        eye = torch.eye(num_nodes, device=x.device).unsqueeze(0).unsqueeze(0)
        mask = torch.maximum(mask, eye)  # Add self-loops
        scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
out = torch.matmul(attn_weights, v)
out = out.transpose(1, 2).contiguous().view(batch_size, num_nodes, self.hidden_dim)
return self.out_proj(out)
class GraphTransformerBlock(nn.Module):
"""Transformer block with graph-aware attention."""
def __init__(
self,
hidden_dim: int,
num_heads: int,
dropout: float = 0.1
):
super().__init__()
self.attention = GraphAttention(hidden_dim, num_heads, dropout)
self.norm1 = nn.LayerNorm(hidden_dim)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 4, hidden_dim),
nn.Dropout(dropout)
)
self.norm2 = nn.LayerNorm(hidden_dim)
def forward(
self,
x: torch.Tensor,
adjacency: torch.Tensor
) -> torch.Tensor:
"""Apply graph transformer block."""
x = x + self.attention(self.norm1(x), adjacency)
x = x + self.ffn(self.norm2(x))
return x
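# Design note: GraphTransformerBlock uses the pre-norm residual layout
# (LayerNorm applied before each sublayer), which is generally more stable
# to train than post-norm at moderate depth; see
# https://arxiv.org/abs/2002.04745.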
class ExecutionEncoder(nn.Module):
"""
Graph-based transformer encoder mapping execution plans to z_e ∈ R^1024.
Encodes:
- Tool invocation sequences
- Data flow dependencies (graph edges)
- Provenance metadata (trust tiers)
- Scope metadata (volume + sensitivity)
Performance targets:
- Latency: <100ms on CPU (pairs with GovernanceEncoder's 98ms)
- Memory: <500MB
- Differentiable: Yes
"""
def __init__(
self,
latent_dim: int = 1024,
hidden_dim: int = 512,
num_layers: int = 4,
num_heads: int = 8,
max_nodes: int = 64,
dropout: float = 0.1,
vocab_size: int = 10000
):
super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.max_nodes = max_nodes
        self.vocab_size = vocab_size
# Token embeddings for tool names and arguments
self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
self.position_embedding = nn.Embedding(max_nodes, hidden_dim)
# Provenance and scope embeddings
self.provenance_embedding = ProvenanceEmbedding(hidden_dim)
# Graph transformer layers
self.layers = nn.ModuleList([
GraphTransformerBlock(hidden_dim, num_heads, dropout)
for _ in range(num_layers)
])
# Pooling and projection
self.attention_pool = nn.Linear(hidden_dim, 1)
self.projection = nn.Sequential(
nn.Linear(hidden_dim, latent_dim * 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(latent_dim * 2, latent_dim),
nn.LayerNorm(latent_dim)
)
self.input_norm = nn.LayerNorm(hidden_dim)
    def _tokenize(self, text: str) -> int:
        """
        Hash-based tokenization (v0.1.0).
        Uses a stable hash (zlib.crc32) so token IDs are reproducible across
        processes; Python's built-in hash() is salted per interpreter run.
        Token ID 0 is reserved for padding.
        Future: Replace with BPE tokenizer (v0.2.0) to reduce collisions.
        """
        return 1 + zlib.crc32(text.encode("utf-8")) % (self.vocab_size - 1)
def _create_adjacency_matrix(
self,
num_nodes: int,
edges: list[tuple[int, int]],
device: torch.device
) -> torch.Tensor:
"""Build adjacency matrix from edge list."""
adjacency = torch.zeros(num_nodes, num_nodes, device=device)
for src, dst in edges:
if src < num_nodes and dst < num_nodes:
adjacency[src, dst] = 1
return adjacency
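    # Worked example: _create_adjacency_matrix(3, [(0, 1), (1, 2)], device)
    # yields
    #   [[0., 1., 0.],
    #    [0., 0., 1.],
    #    [0., 0., 0.]]
    # i.e. adjacency[src, dst] = 1, so in GraphAttention node 0 attends to
    # node 1 and node 1 attends to node 2 (self-loops are added separately
    # inside GraphAttention.forward).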
def forward(
self,
plan: ExecutionPlan | dict[str, Any]
) -> torch.Tensor:
"""
Encode execution plan into latent vector.
Args:
plan: ExecutionPlan or dict conforming to ExecutionPlan schema
Returns:
z_e: Latent vector [1, latent_dim]
"""
# Validate and parse input
if not isinstance(plan, ExecutionPlan):
plan = ExecutionPlan(**plan)
nodes = plan.nodes
edges = plan.edges
# Build node ID mapping
node_id_to_idx = {node.node_id: i for i, node in enumerate(nodes)}
edge_indices = [(node_id_to_idx[src], node_id_to_idx[dst]) for src, dst in edges]
# Pad or truncate to max_nodes
num_nodes = min(len(nodes), self.max_nodes)
nodes = nodes[:num_nodes]
# Tokenize tool names and arguments
tool_tokens = []
for node in nodes:
# Combine tool name + serialized args for richer representation
arg_str = ",".join(f"{k}={v}" for k, v in sorted(node.arguments.items()))
combined = f"{node.tool_name}({arg_str})"
tool_tokens.append(self._tokenize(combined))
# Pad tokens
if len(tool_tokens) < self.max_nodes:
tool_tokens.extend([0] * (self.max_nodes - len(tool_tokens)))
# Infer device from model parameters so tensors land on the right device (cpu/mps/cuda)
device = next(self.parameters()).device
# Convert to tensors
token_ids = torch.tensor(tool_tokens[:self.max_nodes], device=device).unsqueeze(0)
position_ids = torch.arange(self.max_nodes, device=device).unsqueeze(0)
# Provenance and scope metadata
        pad_len = self.max_nodes - num_nodes
        tier_indices = torch.tensor(
            [int(node.provenance_tier) for node in nodes] + [0] * pad_len,
            device=device
        ).unsqueeze(0)
        scope_volume = torch.tensor(
            [node.scope_volume for node in nodes] + [1] * pad_len,
            device=device
        ).unsqueeze(0)
        scope_sensitivity = torch.tensor(
            [node.scope_sensitivity for node in nodes] + [1] * pad_len,
            device=device
        ).unsqueeze(0)
# Build adjacency matrix
adjacency = self._create_adjacency_matrix(
self.max_nodes,
edge_indices,
device
).unsqueeze(0)
# Embed tokens
token_emb = self.token_embedding(token_ids)
pos_emb = self.position_embedding(position_ids)
prov_emb = self.provenance_embedding(tier_indices, scope_volume, scope_sensitivity)
# Combine embeddings
x = token_emb + pos_emb + prov_emb
x = self.input_norm(x)
# Apply graph transformer layers
for layer in self.layers:
x = layer(x, adjacency)
        # Attention pooling over real nodes only; padded positions are masked
        # out so they cannot dilute the pooled representation
        attn_scores = self.attention_pool(x).squeeze(-1)
        pad_mask = torch.arange(self.max_nodes, device=device) >= num_nodes
        attn_scores = attn_scores.masked_fill(pad_mask.unsqueeze(0), float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1).unsqueeze(1)
        pooled = torch.matmul(attn_weights, x).squeeze(1)
# Project to latent space
z_e = self.projection(pooled)
return z_e
def encode_batch(self, plans: list[ExecutionPlan]) -> torch.Tensor:
"""
Batch encoding of multiple execution plans.
Args:
plans: List of ExecutionPlan objects
Returns:
z_e: Latent vectors [batch_size, latent_dim]
"""
latents = [self.forward(plan) for plan in plans]
return torch.cat(latents, dim=0)
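# Note: encode_batch encodes plans one at a time and concatenates the results;
# each forward() call already pads to max_nodes, so a fused batched path would
# mainly save Python-loop overhead rather than change the outputs.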
def create_execution_encoder(
latent_dim: int = 1024,
checkpoint_path: str | None = None,
device: str = "cpu"
) -> ExecutionEncoder:
"""
Factory function to create ExecutionEncoder.
Args:
latent_dim: Dimension of output latent vector (must match GovernanceEncoder)
checkpoint_path: Optional path to pretrained weights
device: Device to load model on
Returns:
Initialized ExecutionEncoder in inference mode
"""
model = ExecutionEncoder(latent_dim=latent_dim)
if checkpoint_path is not None:
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
model = model.to(device)
    model.eval()  # Inference mode: eval() propagates to submodules (e.g. Dropout), unlike setting .training directly
return model
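if __name__ == "__main__":
    # Minimal smoke test (illustrative only; tool names, node IDs, and
    # argument values below are hypothetical). Builds a two-node plan,
    # encodes it, and prints the latent shape.
    encoder = create_execution_encoder()
    demo_plan = ExecutionPlan(
        nodes=[
            ToolCallNode(node_id="n1", tool_name="db.query",
                         arguments={"table": "users"},
                         provenance_tier=TrustTier.INTERNAL,
                         scope_volume=500, scope_sensitivity=3),
            ToolCallNode(node_id="n2", tool_name="http.post",
                         provenance_tier=TrustTier.PUBLIC_WEB),
        ],
        edges=[("n1", "n2")],
    )
    with torch.no_grad():
        z_e = encoder(demo_plan)
    print(f"z_e shape: {tuple(z_e.shape)}")  # expected: (1, 1024)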