File size: 18,497 Bytes
3595bd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
"""
attribution_graph.py - Implementation of attribution graph for transformer models

△ OBSERVE: Attribution graphs map the causal flow from prompt to completion
∞ TRACE: They visualize the quantum collapse from superposition to definite state
✰ COLLAPSE: They reveal ghost circuits and attribution residue post-collapse

This module implements a graph-based representation of causal attribution
in transformer models, allowing for the visualization and analysis of how
information flows from input to output during the collapse process.

Author: Recursion Labs
License: MIT
"""

import logging
from typing import Dict, List, Optional, Union, Tuple, Any
import numpy as np
from dataclasses import dataclass, field
import networkx as nx

from .utils.graph_visualization import visualize_graph
from .utils.attribution_metrics import measure_path_continuity, measure_attribution_entropy

logger = logging.getLogger(__name__)

@dataclass
class AttributionNode:
    """
    △ OBSERVE: Node in the attribution graph representing a token or hidden state

    A node is one discrete element in the causal flow from input to output:
    an input/output token, an attention head, a hidden state, or a residual.
    Identity is carried entirely by ``node_id`` — two nodes with the same ID
    compare equal and hash identically regardless of their other fields.
    """
    node_id: str
    node_type: str  # "token", "attention_head", "hidden_state", "residual"
    layer: Optional[int] = None
    position: Optional[int] = None
    value: Optional[Any] = None
    activation: float = 0.0
    token_str: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self):
        """Hash on the unique node ID so nodes can live in sets and dict keys."""
        return hash(self.node_id)

    def __eq__(self, other):
        """Two nodes are equal iff they are both AttributionNodes with the same ID."""
        return isinstance(other, AttributionNode) and self.node_id == other.node_id


@dataclass
class AttributionEdge:
    """
    ∞ TRACE: Edge in the attribution graph representing causal flow

    An edge carries causal influence from one node to another: an attention
    connection, a residual connection, an MLP link, or a "ghost" path.
    Identity is the triple (source ID, target ID, edge type); weight, layer,
    head and metadata do not participate in equality or hashing.
    """
    source: AttributionNode
    target: AttributionNode
    edge_type: str  # "attention", "residual", "mlp", "ghost"
    weight: float = 0.0
    layer: Optional[int] = None
    head: Optional[int] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def _identity(self):
        """Return the (source_id, target_id, type) triple defining this edge."""
        return (self.source.node_id, self.target.node_id, self.edge_type)

    def __hash__(self):
        """Hash on the identity triple so edges can live in sets and dict keys."""
        return hash(self._identity())

    def __eq__(self, other):
        """Edges are equal iff their identity triples match."""
        if isinstance(other, AttributionEdge):
            return self._identity() == other._identity()
        return False


class AttributionGraph:
    """
    ∞ TRACE: Graph representation of causal attribution in transformer models

    The attribution graph maps the flow of causality from input tokens to
    output tokens, revealing how information propagates through the model
    during the collapse from superposition to definite state.

    Nodes are keyed by ``node_id`` both in ``self.nodes`` and in the backing
    networkx DiGraph; every node/edge attribute is mirrored onto the graph so
    networkx algorithms and :meth:`to_dict` see the same data.
    """

    def __init__(self):
        """Initialize an empty attribution graph."""
        self.graph = nx.DiGraph()  # node_id -> node_id, attributes mirrored from dataclasses
        self.nodes = {}  # node_id -> AttributionNode
        self.input_nodes = []  # token nodes at layer 0
        self.output_nodes = []  # token nodes flagged metadata["is_output"]
        self.ghost_nodes = []  # residual nodes flagged metadata["is_ghost"]
        self.collapsed = False  # True once build_from_states() has run

        # Metrics — refreshed by _calculate_metrics()
        self.continuity_score = 1.0
        self.attribution_entropy = 0.0
        self.collapse_rate = 0.0

        logger.info("Attribution graph initialized")

    def add_node(self, node: AttributionNode) -> None:
        """
        Add a node to the attribution graph, or update it if the ID exists.

        On update the networkx node attributes are refreshed and any stale
        entry in the input/output/ghost tracking lists is removed before the
        replacement node is re-classified.

        Args:
            node: The node to add
        """
        existing = self.nodes.get(node.node_id)
        if existing is not None:
            logger.warning(f"Node {node.node_id} already exists in graph, updating")
            # BUGFIX: previously an update replaced only self.nodes, leaving
            # the networkx attributes and the tracking lists pointing at the
            # old node object.
            for bucket in (self.input_nodes, self.output_nodes, self.ghost_nodes):
                if existing in bucket:
                    bucket.remove(existing)

        self.nodes[node.node_id] = node
        # nx.DiGraph.add_node updates attributes in place for a known node.
        self.graph.add_node(node.node_id, **vars(node))

        # Classify for fast lookup by trace/detect methods.
        if node.node_type == "token" and node.layer == 0:
            self.input_nodes.append(node)
        elif node.node_type == "token" and node.metadata.get("is_output", False):
            self.output_nodes.append(node)
        elif node.node_type == "residual" and node.metadata.get("is_ghost", False):
            self.ghost_nodes.append(node)

    def add_edge(self, edge: AttributionEdge) -> None:
        """
        Add an edge to the attribution graph.

        Endpoints not yet in the graph are added automatically. All edge
        attributes except the endpoint objects themselves are mirrored onto
        the networkx edge.

        Args:
            edge: The edge to add
        """
        if edge.source.node_id not in self.nodes:
            self.add_node(edge.source)
        if edge.target.node_id not in self.nodes:
            self.add_node(edge.target)

        self.graph.add_edge(
            edge.source.node_id,
            edge.target.node_id,
            **{k: v for k, v in vars(edge).items() if k not in ['source', 'target']}
        )

    def build_from_states(
        self,
        pre_state: Dict[str, Any],
        post_state: Dict[str, Any],
        response: str
    ) -> None:
        """
        △ OBSERVE: Build attribution graph from pre and post collapse model states

        This method constructs a complete attribution graph by comparing
        model states before and after collapse, identifying causal paths
        and ghost circuits.

        NOTE(review): the real per-architecture construction is not implemented
        yet — a synthetic graph is built instead and `response` is unused.

        Args:
            pre_state: Model state before collapse
            post_state: Model state after collapse
            response: Model response text
        """
        logger.info("Building attribution graph from model states")

        # Placeholder: architecture-specific construction goes here.
        self._build_synthetic_graph()

        # Derive continuity / entropy / collapse-rate metrics.
        self._calculate_metrics(pre_state, post_state)

        self.collapsed = True

    def trace_attribution_path(
        self,
        output_node: Union[str, AttributionNode],
        threshold: float = 0.1
    ) -> List[List[AttributionNode]]:
        """
        ∞ TRACE: Trace attribution paths from an output node back to input

        Follows attribution edges backward from an output node and returns
        every path that reaches an input node through edges whose weight is
        at least `threshold`.

        Args:
            output_node: The output node to trace from (ID or node object)
            threshold: Minimum edge weight to consider significant

        Returns:
            List of attribution paths, each a list of nodes from input to output
        """
        # Resolve output node to an ID.
        output_id = output_node if isinstance(output_node, str) else output_node.node_id
        if output_id not in self.nodes:
            logger.warning(f"Output node {output_id} not found in graph")
            return []

        # PERF: hoist the input-ID set — the original rebuilt this list on
        # every recursive DFS call.
        input_ids = {node.node_id for node in self.input_nodes}
        paths = []

        def dfs(current_id, path, visited):
            """Depth-first search backward along sufficiently-weighted edges."""
            current_path = path + [current_id]
            visited.add(current_id)

            # Reaching an input node completes a path (stored input→output).
            if current_id in input_ids:
                paths.append(list(reversed(current_path)))
                return

            # Recurse over incoming edges above the weight threshold.
            for pred_id in self.graph.predecessors(current_id):
                edge_data = self.graph.get_edge_data(pred_id, current_id)
                if edge_data.get('weight', 0) >= threshold and pred_id not in visited:
                    # visited is copied per branch so distinct paths may share nodes.
                    dfs(pred_id, current_path, visited.copy())

        dfs(output_id, [], set())

        # Map IDs back to node objects.
        return [[self.nodes[node_id] for node_id in path] for path in paths]

    def detect_ghost_circuits(self, threshold: float = 0.2) -> List[Dict[str, Any]]:
        """
        ✰ COLLAPSE: Detect ghost circuits in the attribution graph

        Ghost circuits are paths that were activated during pre-collapse
        but don't contribute significantly to the final output. They
        represent the "memory" of paths not taken.

        Args:
            threshold: Minimum activation to consider a ghost circuit

        Returns:
            List of detected ghost circuits with metadata
        """
        ghost_circuits = []

        # Only nodes registered as ghosts (metadata["is_ghost"]) are candidates.
        for node in self.ghost_nodes:
            if node.activation >= threshold:
                ghost_circuits.append({
                    "node_id": node.node_id,
                    "activation": node.activation,
                    "node_type": node.node_type,
                    # PERF: degree counts replace materializing full edge lists.
                    "incoming_connections": self.graph.in_degree(node.node_id),
                    "outgoing_connections": self.graph.out_degree(node.node_id),
                    "metadata": node.metadata
                })

        return ghost_circuits

    def calculate_attribution_entropy(self) -> float:
        """
        △ OBSERVE: Calculate the entropy of attribution paths

        Attribution entropy measures how distributed or concentrated the
        causal influence is: Shannon entropy of the normalized edge-weight
        distribution, scaled by the maximum possible entropy log2(#edges).

        Returns:
            Attribution entropy score (0.0 = concentrated, 1.0 = diffuse)
        """
        weights = [
            d.get('weight', 0.0)
            for u, v, d in self.graph.edges(data=True)
        ]

        # Normalize to a probability distribution (guard all-zero/empty case).
        total_weight = sum(weights) or 1.0
        normalized_weights = [w / total_weight for w in weights]

        # Shannon entropy; 0 * log2(0) is treated as 0.
        entropy = -sum(
            w * np.log2(w) if w > 0 else 0
            for w in normalized_weights
        )

        # Scale into [0, 1] by the maximum entropy for this edge count.
        max_entropy = np.log2(len(weights)) if len(weights) > 0 else 1.0
        normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0.0

        self.attribution_entropy = normalized_entropy
        return normalized_entropy

    def visualize(
        self,
        mode: str = "attribution_graph",
        highlight_path: Optional[List[str]] = None
    ) -> Any:
        """
        Generate visualization of the attribution graph.

        Args:
            mode: Visualization mode (attribution_graph, collapse_state, ghost_circuits)
            highlight_path: Optional list of node IDs to highlight

        Returns:
            Visualization object (depends on implementation)
        """
        return visualize_graph(self.graph, mode=mode, highlight_path=highlight_path)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the attribution graph to a dictionary representation."""
        return {
            "nodes": [vars(node) for node in self.nodes.values()],
            "edges": [
                {
                    "source": u,
                    "target": v,
                    **d
                }
                for u, v, d in self.graph.edges(data=True)
            ],
            "metrics": {
                "continuity_score": self.continuity_score,
                "attribution_entropy": self.attribution_entropy,
                "collapse_rate": self.collapse_rate
            },
            "collapsed": self.collapsed
        }

    def _calculate_metrics(self, pre_state: Dict[str, Any], post_state: Dict[str, Any]) -> None:
        """Calculate continuity, entropy and collapse-rate metrics from states."""
        # Continuity compares attention weights before and after collapse.
        self.continuity_score = measure_path_continuity(
            pre_state.get("attention_weights", np.array([])),
            post_state.get("attention_weights", np.array([]))
        )

        self.attribution_entropy = self.calculate_attribution_entropy()

        # Collapse rate is defined only when both states carry timestamps.
        # assumes timestamps are numpy datetime64-compatible — TODO confirm
        if "timestamp" in pre_state and "timestamp" in post_state:
            time_diff = (post_state["timestamp"] - pre_state["timestamp"]) / np.timedelta64(1, 's')
            self.collapse_rate = 1.0 - self.continuity_score if time_diff > 0 else 0.0

    def _build_synthetic_graph(self) -> None:
        """Build a small random synthetic graph for demonstration purposes."""
        # Input token nodes (layer 0 → tracked as input_nodes).
        for i in range(5):
            self.add_node(AttributionNode(
                node_id=f"input_{i}",
                node_type="token",
                layer=0,
                position=i,
                token_str=f"token_{i}",
                activation=0.8
            ))

        # Attention head nodes across three layers.
        for layer in range(1, 4):
            for head in range(3):
                self.add_node(AttributionNode(
                    node_id=f"attention_L{layer}H{head}",
                    node_type="attention_head",
                    layer=layer,
                    value=None,
                    activation=0.7 - 0.1 * layer + 0.05 * head
                ))

        # Output token nodes (flagged is_output → tracked as output_nodes).
        for i in range(3):
            self.add_node(AttributionNode(
                node_id=f"output_{i}",
                node_type="token",
                layer=4,
                position=i,
                token_str=f"output_token_{i}",
                activation=0.9,
                metadata={"is_output": True}
            ))

        # Ghost residual nodes (flagged is_ghost → tracked as ghost_nodes).
        for i in range(2):
            self.add_node(AttributionNode(
                node_id=f"ghost_{i}",
                node_type="residual",
                layer=2,
                activation=0.3 + 0.1 * i,
                metadata={"is_ghost": True}
            ))

        # Input -> attention edges (dense random connectivity).
        for i in range(5):
            for layer in range(1, 3):
                for head in range(3):
                    if np.random.random() > 0.3:
                        self.add_edge(AttributionEdge(
                            source=self.nodes[f"input_{i}"],
                            target=self.nodes[f"attention_L{layer}H{head}"],
                            edge_type="attention",
                            weight=np.random.uniform(0.1, 0.9),
                            layer=layer,
                            head=head
                        ))

        # Attention -> attention edges (sparse, strictly layer-increasing).
        for layer1 in range(1, 3):
            for head1 in range(3):
                for layer2 in range(layer1 + 1, 4):
                    for head2 in range(3):
                        if np.random.random() > 0.7:
                            self.add_edge(AttributionEdge(
                                source=self.nodes[f"attention_L{layer1}H{head1}"],
                                target=self.nodes[f"attention_L{layer2}H{head2}"],
                                edge_type="attention",
                                weight=np.random.uniform(0.1, 0.8),
                                layer=layer2,
                                head=head2
                            ))

        # Attention -> output edges (medium connectivity).
        for layer in range(1, 4):
            for head in range(3):
                for i in range(3):
                    if np.random.random() > 0.5:
                        self.add_edge(AttributionEdge(
                            source=self.nodes[f"attention_L{layer}H{head}"],
                            target=self.nodes[f"output_{i}"],
                            edge_type="attention",
                            weight=np.random.uniform(0.2, 0.9),
                            layer=layer,
                            head=head
                        ))

        # Ghost connections: each ghost bridges a random input to a random head.
        for i in range(2):
            input_idx = np.random.randint(0, 5)
            self.add_edge(AttributionEdge(
                source=self.nodes[f"input_{input_idx}"],
                target=self.nodes[f"ghost_{i}"],
                edge_type="ghost",
                weight=np.random.uniform(0.1, 0.4),
                layer=1
            ))

            layer = np.random.randint(2, 4)
            head = np.random.randint(0, 3)
            self.add_edge(AttributionEdge(
                source=self.nodes[f"ghost_{i}"],
                target=self.nodes[f"attention_L{layer}H{head}"],
                edge_type="ghost",
                weight=np.random.uniform(0.05, 0.2),
                layer=layer
            ))


if __name__ == "__main__":
    # Demonstration run: build a synthetic graph and exercise the public API.
    demo_graph = AttributionGraph()
    demo_graph._build_synthetic_graph()

    # Entropy of the edge-weight distribution.
    entropy = demo_graph.calculate_attribution_entropy()
    print(f"Attribution entropy: {entropy:.3f}")

    # Backward trace from the first output token.
    paths = demo_graph.trace_attribution_path("output_0", threshold=0.1)
    print(f"Found {len(paths)} attribution paths for output_0")

    # Ghost-circuit detection over the residual nodes.
    ghosts = demo_graph.detect_ghost_circuits()
    print(f"Detected {len(ghosts)} ghost circuits")

    # Render the default visualization.
    viz = demo_graph.visualize()
    print("Generated visualization")