# schrodingers-classifiers / attribution_graph.py
# (Hugging Face upload-page header preserved from the original capture:
#  recursivelabs — "Upload 14 files" — commit 3595bd8 verified — raw / history / blame — 18.5 kB)
"""
attribution_graph.py - Implementation of attribution graph for transformer models
△ OBSERVE: Attribution graphs map the causal flow from prompt to completion
∞ TRACE: They visualize the quantum collapse from superposition to definite state
✰ COLLAPSE: They reveal ghost circuits and attribution residue post-collapse
This module implements a graph-based representation of causal attribution
in transformer models, allowing for the visualization and analysis of how
information flows from input to output during the collapse process.
Author: Recursion Labs
License: MIT
"""
import logging
from typing import Dict, List, Optional, Union, Tuple, Any
import numpy as np
from dataclasses import dataclass, field
import networkx as nx
from .utils.graph_visualization import visualize_graph
from .utils.attribution_metrics import measure_path_continuity, measure_attribution_entropy
logger = logging.getLogger(__name__)
@dataclass
class AttributionNode:
    """
    △ OBSERVE: Node in the attribution graph representing a token or hidden state

    Each node stands for one discrete element in the causal flow from input
    to output: a token, an attention head, or a hidden state. Identity is
    carried entirely by ``node_id``; all other fields are descriptive payload.
    """
    node_id: str
    node_type: str  # "token", "attention_head", "hidden_state", "residual"
    layer: Optional[int] = None
    position: Optional[int] = None
    value: Optional[Any] = None
    activation: float = 0.0
    token_str: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self):
        """Hash by ID so nodes can be used in sets and as dict keys."""
        return hash(self.node_id)

    def __eq__(self, other):
        """Two nodes are equal iff they are both nodes and share an ID."""
        return isinstance(other, AttributionNode) and self.node_id == other.node_id
@dataclass
class AttributionEdge:
    """
    ∞ TRACE: Edge in the attribution graph representing causal flow

    An edge carries causal influence between two nodes. It may represent an
    attention connection, a residual connection, an MLP path, or a "ghost"
    trace left behind by a path not taken. Identity is the triple
    (source id, target id, edge type).
    """
    source: AttributionNode
    target: AttributionNode
    edge_type: str  # "attention", "residual", "mlp", "ghost"
    weight: float = 0.0
    layer: Optional[int] = None
    head: Optional[int] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self):
        """Hash on (source id, target id, type) so edges deduplicate in sets."""
        return hash((self.source.node_id, self.target.node_id, self.edge_type))

    def __eq__(self, other):
        """Edges match when endpoints and type all agree."""
        if not isinstance(other, AttributionEdge):
            return False
        same_endpoints = (
            self.source.node_id == other.source.node_id
            and self.target.node_id == other.target.node_id
        )
        return same_endpoints and self.edge_type == other.edge_type
class AttributionGraph:
    """
    ∞ TRACE: Graph representation of causal attribution in transformer models

    The attribution graph maps the flow of causality from input tokens to
    output tokens, revealing how information propagates through the model
    during the collapse from superposition to definite state.

    Attributes:
        graph: Underlying ``nx.DiGraph`` keyed by node_id, mirroring ``nodes``.
        nodes: Mapping of node_id -> AttributionNode (source of truth).
        input_nodes: Token nodes at layer 0.
        output_nodes: Token nodes flagged ``metadata["is_output"]``.
        ghost_nodes: Residual nodes flagged ``metadata["is_ghost"]``.
        collapsed: True once build_from_states() has run.
    """

    def __init__(self):
        """Initialize an empty attribution graph."""
        self.graph = nx.DiGraph()
        self.nodes = {}  # node_id -> AttributionNode
        self.input_nodes = []  # List of input token nodes
        self.output_nodes = []  # List of output token nodes
        self.ghost_nodes = []  # List of ghost circuit nodes
        self.collapsed = False  # Whether the graph has been collapsed

        # Metrics (recomputed by _calculate_metrics / calculate_attribution_entropy)
        self.continuity_score = 1.0
        self.attribution_entropy = 0.0
        self.collapse_rate = 0.0

        logger.info("Attribution graph initialized")

    def add_node(self, node: AttributionNode) -> None:
        """
        Add a node to the attribution graph, or update it if already present.

        Args:
            node: The node to add
        """
        if node.node_id in self.nodes:
            logger.warning(f"Node {node.node_id} already exists in graph, updating")
        self.nodes[node.node_id] = node
        # BUGFIX: networkx attributes were previously written only for
        # brand-new nodes, leaving the graph view stale on update. Calling
        # add_node on an existing id simply refreshes its attributes.
        self.graph.add_node(node.node_id, **vars(node))

        # Track input, output, and ghost nodes.
        # BUGFIX: guard against duplicate list entries when a node is
        # re-added/updated (membership uses AttributionNode.__eq__, i.e. id).
        if node.node_type == "token" and node.layer == 0:
            if node not in self.input_nodes:
                self.input_nodes.append(node)
        elif node.node_type == "token" and node.metadata.get("is_output", False):
            if node not in self.output_nodes:
                self.output_nodes.append(node)
        elif node.node_type == "residual" and node.metadata.get("is_ghost", False):
            if node not in self.ghost_nodes:
                self.ghost_nodes.append(node)

    def add_edge(self, edge: AttributionEdge) -> None:
        """
        Add an edge to the attribution graph, creating endpoints as needed.

        Args:
            edge: The edge to add
        """
        # Ensure both endpoints exist before wiring them together
        if edge.source.node_id not in self.nodes:
            self.add_node(edge.source)
        if edge.target.node_id not in self.nodes:
            self.add_node(edge.target)

        # Store all edge attributes except the node objects themselves
        # (the graph is keyed by node_id, so the objects would be redundant).
        self.graph.add_edge(
            edge.source.node_id,
            edge.target.node_id,
            **{k: v for k, v in vars(edge).items() if k not in ['source', 'target']}
        )

    def build_from_states(
        self,
        pre_state: Dict[str, Any],
        post_state: Dict[str, Any],
        response: str
    ) -> None:
        """
        △ OBSERVE: Build attribution graph from pre and post collapse model states

        This method constructs a complete attribution graph by comparing
        model states before and after collapse, identifying causal paths
        and ghost circuits.

        Args:
            pre_state: Model state before collapse
            post_state: Model state after collapse
            response: Model response text (currently unused by the synthetic
                builder; kept for architecture-specific implementations)
        """
        logger.info("Building attribution graph from model states")

        # NOTE(review): architecture-specific construction is not implemented
        # here — a synthetic demonstration graph is built instead.
        self._build_synthetic_graph()

        # Calculate graph metrics from the provided states
        self._calculate_metrics(pre_state, post_state)

        # Mark graph as collapsed
        self.collapsed = True

    def trace_attribution_path(
        self,
        output_node: Union[str, AttributionNode],
        threshold: float = 0.1
    ) -> List[List[AttributionNode]]:
        """
        ∞ TRACE: Trace attribution paths from an output node back to input

        This method follows attribution edges backward from an output node
        to find all significant input nodes that influenced it.

        Args:
            output_node: The output node to trace from (ID or node object)
            threshold: Minimum edge weight to consider significant

        Returns:
            List of attribution paths, each a list of nodes from input to output
        """
        # Resolve output node to its ID
        output_id = output_node if isinstance(output_node, str) else output_node.node_id
        if output_id not in self.nodes:
            logger.warning(f"Output node {output_id} not found in graph")
            return []

        # PERF: hoist the input-id set out of the DFS — previously rebuilt
        # as a list on every recursive call (O(inputs) per visited node).
        input_ids = {node.node_id for node in self.input_nodes}

        paths = []

        def dfs(current_id, path, visited):
            """Depth-first search backward along significant edges."""
            current_path = path + [current_id]
            visited.add(current_id)

            # Reaching an input node completes a path (built output->input,
            # so reverse it to report input->output order)
            if current_id in input_ids:
                paths.append(list(reversed(current_path)))
                return

            # Continue DFS on incoming edges above the weight threshold
            for pred_id in self.graph.predecessors(current_id):
                edge_data = self.graph.get_edge_data(pred_id, current_id)
                if edge_data.get('weight', 0) >= threshold and pred_id not in visited:
                    # visited.copy() keeps branches independent so a node may
                    # appear on multiple distinct paths
                    dfs(pred_id, current_path, visited.copy())

        # Start DFS from the output node
        dfs(output_id, [], set())

        # Convert node IDs back to node objects
        return [[self.nodes[node_id] for node_id in path] for path in paths]

    def detect_ghost_circuits(self, threshold: float = 0.2) -> List[Dict[str, Any]]:
        """
        ✰ COLLAPSE: Detect ghost circuits in the attribution graph

        Ghost circuits are paths that were activated during pre-collapse
        but don't contribute significantly to the final output. They
        represent the "memory" of paths not taken.

        Args:
            threshold: Minimum activation to consider a ghost circuit

        Returns:
            List of detected ghost circuits with metadata
        """
        ghost_circuits = []

        # Only nodes registered as ghosts (metadata["is_ghost"]) are considered
        for node in self.ghost_nodes:
            if node.activation >= threshold:
                # Count the connections this ghost node participates in
                incoming_edges = [
                    (u, v, d) for u, v, d in self.graph.in_edges(node.node_id, data=True)
                ]
                outgoing_edges = [
                    (u, v, d) for u, v, d in self.graph.out_edges(node.node_id, data=True)
                ]

                ghost_circuits.append({
                    "node_id": node.node_id,
                    "activation": node.activation,
                    "node_type": node.node_type,
                    "incoming_connections": len(incoming_edges),
                    "outgoing_connections": len(outgoing_edges),
                    "metadata": node.metadata
                })

        return ghost_circuits

    def calculate_attribution_entropy(self) -> float:
        """
        △ OBSERVE: Calculate the entropy of attribution paths

        Attribution entropy measures how distributed or concentrated
        the causal influence is in the graph. High entropy indicates
        diffuse attribution, while low entropy indicates concentrated
        attribution.

        Returns:
            Attribution entropy score (0.0 = concentrated, 1.0 = diffuse)
        """
        # Extract edge weights
        weights = [
            d.get('weight', 0.0)
            for u, v, d in self.graph.edges(data=True)
        ]

        # Normalize weights into a probability distribution
        total_weight = sum(weights) or 1.0  # Avoid division by zero
        normalized_weights = [w / total_weight for w in weights]

        # Shannon entropy in bits; zero-weight edges contribute nothing
        entropy = -sum(
            w * np.log2(w) if w > 0 else 0
            for w in normalized_weights
        )

        # Normalize entropy to 0-1 range (max entropy = log2(num_edges))
        max_entropy = np.log2(len(weights)) if len(weights) > 0 else 1.0
        normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0.0

        self.attribution_entropy = normalized_entropy
        return normalized_entropy

    def visualize(
        self,
        mode: str = "attribution_graph",
        highlight_path: Optional[List[str]] = None
    ) -> Any:
        """
        Generate visualization of the attribution graph.

        Args:
            mode: Visualization mode (attribution_graph, collapse_state, ghost_circuits)
            highlight_path: Optional list of node IDs to highlight

        Returns:
            Visualization object (depends on implementation)
        """
        return visualize_graph(self.graph, mode=mode, highlight_path=highlight_path)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the attribution graph to a dictionary representation."""
        return {
            "nodes": [vars(node) for node in self.nodes.values()],
            "edges": [
                {
                    "source": u,
                    "target": v,
                    **d
                }
                for u, v, d in self.graph.edges(data=True)
            ],
            "metrics": {
                "continuity_score": self.continuity_score,
                "attribution_entropy": self.attribution_entropy,
                "collapse_rate": self.collapse_rate
            },
            "collapsed": self.collapsed
        }

    def _calculate_metrics(self, pre_state: Dict[str, Any], post_state: Dict[str, Any]) -> None:
        """Calculate continuity, entropy, and collapse-rate metrics.

        Args:
            pre_state: Model state before collapse (may carry "attention_weights",
                "timestamp")
            post_state: Model state after collapse (same keys)
        """
        # Continuity between pre- and post-collapse attention patterns
        self.continuity_score = measure_path_continuity(
            pre_state.get("attention_weights", np.array([])),
            post_state.get("attention_weights", np.array([]))
        )

        # Entropy of the current edge-weight distribution
        self.attribution_entropy = self.calculate_attribution_entropy()

        # Collapse rate: only meaningful when a positive time delta exists
        if "timestamp" in pre_state and "timestamp" in post_state:
            time_diff = (post_state["timestamp"] - pre_state["timestamp"]) / np.timedelta64(1, 's')
            self.collapse_rate = 1.0 - self.continuity_score if time_diff > 0 else 0.0

    def _build_synthetic_graph(self) -> None:
        """Build a synthetic graph for demonstration purposes.

        NOTE: uses np.random without a fixed seed, so the topology and
        weights differ between runs by design.
        """
        # Create input token nodes (layer 0)
        for i in range(5):
            self.add_node(AttributionNode(
                node_id=f"input_{i}",
                node_type="token",
                layer=0,
                position=i,
                token_str=f"token_{i}",
                activation=0.8
            ))

        # Create attention head nodes for layers 1-3
        for layer in range(1, 4):
            for head in range(3):
                self.add_node(AttributionNode(
                    node_id=f"attention_L{layer}H{head}",
                    node_type="attention_head",
                    layer=layer,
                    value=None,
                    activation=0.7 - 0.1 * layer + 0.05 * head
                ))

        # Create output token nodes
        for i in range(3):
            self.add_node(AttributionNode(
                node_id=f"output_{i}",
                node_type="token",
                layer=4,
                position=i,
                token_str=f"output_token_{i}",
                activation=0.9,
                metadata={"is_output": True}
            ))

        # Create ghost nodes
        for i in range(2):
            self.add_node(AttributionNode(
                node_id=f"ghost_{i}",
                node_type="residual",
                layer=2,
                activation=0.3 + 0.1 * i,
                metadata={"is_ghost": True}
            ))

        # Create edges
        # Input to attention
        for i in range(5):
            for layer in range(1, 3):
                for head in range(3):
                    if np.random.random() > 0.3:  # Random connectivity
                        self.add_edge(AttributionEdge(
                            source=self.nodes[f"input_{i}"],
                            target=self.nodes[f"attention_L{layer}H{head}"],
                            edge_type="attention",
                            weight=np.random.uniform(0.1, 0.9),
                            layer=layer,
                            head=head
                        ))

        # Attention to attention (strictly later layers only)
        for layer1 in range(1, 3):
            for head1 in range(3):
                for layer2 in range(layer1 + 1, 4):
                    for head2 in range(3):
                        if np.random.random() > 0.7:  # Sparse connectivity
                            self.add_edge(AttributionEdge(
                                source=self.nodes[f"attention_L{layer1}H{head1}"],
                                target=self.nodes[f"attention_L{layer2}H{head2}"],
                                edge_type="attention",
                                weight=np.random.uniform(0.1, 0.8),
                                layer=layer2,
                                head=head2
                            ))

        # Attention to output
        for layer in range(1, 4):
            for head in range(3):
                for i in range(3):
                    if np.random.random() > 0.5:  # Medium connectivity
                        self.add_edge(AttributionEdge(
                            source=self.nodes[f"attention_L{layer}H{head}"],
                            target=self.nodes[f"output_{i}"],
                            edge_type="attention",
                            weight=np.random.uniform(0.2, 0.9),
                            layer=layer,
                            head=head
                        ))

        # Ghost connections (weak input->ghost->attention links)
        for i in range(2):
            # Input to ghost
            input_idx = np.random.randint(0, 5)
            self.add_edge(AttributionEdge(
                source=self.nodes[f"input_{input_idx}"],
                target=self.nodes[f"ghost_{i}"],
                edge_type="ghost",
                weight=np.random.uniform(0.1, 0.4),
                layer=1
            ))

            # Ghost to attention
            layer = np.random.randint(2, 4)
            head = np.random.randint(0, 3)
            self.add_edge(AttributionEdge(
                source=self.nodes[f"ghost_{i}"],
                target=self.nodes[f"attention_L{layer}H{head}"],
                edge_type="ghost",
                weight=np.random.uniform(0.05, 0.2),
                layer=layer
            ))
if __name__ == "__main__":
    # Quick demonstration of the attribution graph API.
    demo_graph = AttributionGraph()

    # Populate with the synthetic demo topology
    demo_graph._build_synthetic_graph()

    # Report the entropy of the attribution weight distribution
    print(f"Attribution entropy: {demo_graph.calculate_attribution_entropy():.3f}")

    # Trace causal paths feeding the first output token
    attribution_paths = demo_graph.trace_attribution_path("output_0", threshold=0.1)
    print(f"Found {len(attribution_paths)} attribution paths for output_0")

    # Surface residual ghost circuits
    ghost_list = demo_graph.detect_ghost_circuits()
    print(f"Detected {len(ghost_list)} ghost circuits")

    # Render the default visualization
    _ = demo_graph.visualize()
    print("Generated visualization")