AgentGraph / agentgraph /causal /graph_analysis.py
wu981526092's picture
🚀 Deploy AgentGraph: Complete agent monitoring and knowledge graph system
c2ea5ed
#!/usr/bin/env python3
"""
Causal Graph Analysis
This module implements the core causal graph and analysis logic for the multi-agent system.
It handles perturbation propagation and effect calculation.
"""
from collections import defaultdict
import random
import json
import copy
import numpy as np
import os
from typing import Dict, Set, List, Tuple, Any, Optional, Union
class CausalGraph:
    """
    Represents the causal graph of the multi-agent system derived from the knowledge graph.
    Handles perturbation propagation and effect calculation.

    The knowledge graph is a dict with ``"entities"`` and ``"relations"`` lists
    whose items each carry an ``"id"``. A relation may additionally carry
    ``"source"``/``"target"`` entity IDs, a ``"dependencies"`` mapping with
    optional ``"relations"``/``"entities"`` ID lists, and a numeric score under
    one of ``_OUTCOME_KEYS``.
    """

    # Keys checked, in priority order, for a relation's stored outcome score.
    # The misspelled "purturbation" is what current data uses, so it wins;
    # "perturbation" (correct spelling) and "defense_success_rate" are fallbacks.
    _OUTCOME_KEYS = ("purturbation", "perturbation", "defense_success_rate")

    def __init__(self, knowledge_graph: Dict):
        self.kg = knowledge_graph
        self.entity_ids = [entity["id"] for entity in self.kg["entities"]]
        self.relation_ids = [relation["id"] for relation in self.kg["relations"]]
        # relation ID -> stored outcome score (float; typically 0-1)
        self.relation_outcomes = {}
        # entity/relation ID -> set of relation IDs whose outcome it can affect
        self.relation_dependencies = defaultdict(set)
        self._build_dependency_graph()

    def _build_dependency_graph(self):
        """Build the perturbation dependency graph based on the knowledge graph structure.

        Fills ``relation_outcomes`` for every relation that has a non-None score,
        and ``relation_dependencies`` so that each relation is reachable from
        itself, from its source/target entities, and from every ID it lists
        under ``"dependencies"``.

        Raises:
            ValueError/TypeError: if a present score cannot be converted by ``float``.
        """
        for relation in self.kg["relations"]:
            rel_id = relation["id"]

            # Use the first *non-None* score, so an explicit null under the
            # misspelled key does not hide a valid value under another key.
            score = next(
                (relation[key] for key in self._OUTCOME_KEYS
                 if relation.get(key) is not None),
                None,
            )
            if score is not None:
                self.relation_outcomes[rel_id] = float(score)

            # Explicit dependencies; tolerate a null/missing mapping or lists.
            deps = relation.get("dependencies") or {}
            for dep_id in deps.get("relations") or []:
                self.relation_dependencies[dep_id].add(rel_id)
            for dep_id in deps.get("entities") or []:
                self.relation_dependencies[dep_id].add(rel_id)

            # Self-dependency: a relation can affect its own outcome.
            self.relation_dependencies[rel_id].add(rel_id)

            # Source and target entities affect the relation automatically.
            for endpoint in (relation.get("source"), relation.get("target")):
                if endpoint:
                    self.relation_dependencies[endpoint].add(rel_id)

    def propagate_effects(self, perturbations: Dict[str, float]) -> Dict[str, float]:
        """
        Propagate perturbation effects through the dependency graph.

        Args:
            perturbations: Dictionary mapping relation/entity IDs to their perturbation values (0-1)

        Returns:
            Dictionary mapping each affected relation ID that has a stored outcome
            to its value: the perturbation value if that relation is itself
            perturbed, otherwise its stored outcome. Affected relations with no
            stored outcome are omitted, even when perturbed directly.
        """
        affected_relations = set()
        for node_id in perturbations:
            # .get avoids inserting empty entries into the defaultdict.
            affected_relations.update(self.relation_dependencies.get(node_id, ()))

        outcomes = {}
        for rel_id in affected_relations:
            if rel_id not in self.relation_outcomes:
                continue
            if rel_id in perturbations:
                outcomes[rel_id] = perturbations[rel_id]
            else:
                outcomes[rel_id] = self.relation_outcomes[rel_id]
        return outcomes

    def calculate_outcome(self, perturbations: Optional[Dict[str, float]] = None) -> float:
        """
        Calculate the final outcome score given a set of perturbations.

        Args:
            perturbations: Dictionary mapping relation/entity IDs to their perturbation values (0-1)

        Returns:
            Mean of the affected relation outcomes, or 0.0 when nothing is affected
            (including when ``perturbations`` is empty or None).
        """
        if perturbations is None:
            perturbations = {}
        affected_outcomes = self.propagate_effects(perturbations)
        if not affected_outcomes:
            return 0.0
        # Aggregate outcomes (simple average for now)
        return sum(affected_outcomes.values()) / len(affected_outcomes)
class CausalAnalyzer:
    """
    Performs causal effect analysis on the multi-agent knowledge graph system.
    Calculates Average Causal Effects (ACE) and Shapley values.
    """

    def __init__(self, causal_graph: CausalGraph, n_shapley_samples: int = 200,
                 seed: Optional[int] = None):
        """
        Args:
            causal_graph: Graph whose outcomes are analysed.
            n_shapley_samples: Number of random permutations used to approximate
                Shapley values. A value <= 0 yields all-zero Shapley values.
            seed: Optional seed for a private ``random.Random`` so Shapley
                estimates are reproducible. When None (the default) the
                module-level ``random`` functions are used, so a global
                ``random.seed(...)`` still applies.
        """
        self.causal_graph = causal_graph
        self.n_shapley_samples = n_shapley_samples
        self._rng = random if seed is None else random.Random(seed)
        self.base_outcome = self.causal_graph.calculate_outcome({})

    def set_perturbation_score(self, relation_id: str, score: float) -> None:
        """
        Set the perturbation score for a specific relation ID.
        This allows explicitly setting scores from external sources (like database queries).

        Args:
            relation_id: The ID of the relation to set the score for
            score: The perturbation score value (typically between 0 and 1)
        """
        # Update the relation_outcomes in the causal graph
        self.causal_graph.relation_outcomes[relation_id] = float(score)

    def _perturbation_value(self, node_id: str) -> float:
        """Return the stored outcome for *node_id*, or 1.0 (maximum perturbation) if none."""
        return self.causal_graph.relation_outcomes.get(node_id, 1.0)

    def calculate_ace(self) -> Dict[str, float]:
        """
        Calculate Average Causal Effect (ACE) for each entity and relation.

        Each relation is perturbed with its stored outcome (1.0 if it has none);
        each entity is perturbed with 1.0. The ACE is the resulting outcome
        minus ``base_outcome``.

        Returns:
            Dictionary mapping IDs to their ACE scores
        """
        graph = self.causal_graph
        ace_scores = {}
        for rel_id in graph.relation_ids:
            perturbed = graph.calculate_outcome({rel_id: self._perturbation_value(rel_id)})
            ace_scores[rel_id] = perturbed - self.base_outcome
        for entity_id in graph.entity_ids:
            perturbed = graph.calculate_outcome({entity_id: 1.0})
            ace_scores[entity_id] = perturbed - self.base_outcome
        return ace_scores

    def calculate_shapley_values(self) -> Dict[str, float]:
        """
        Calculate Shapley values to fairly attribute causal effects.

        Approximated by Monte Carlo sampling: for each of ``n_shapley_samples``
        random orderings of all entity and relation IDs, each ID's marginal
        change in outcome when added to the growing coalition is accumulated,
        then the totals are averaged.

        Returns:
            Dictionary mapping IDs to their Shapley values (all 0.0 when
            ``n_shapley_samples`` <= 0)
        """
        all_ids = self.causal_graph.entity_ids + self.causal_graph.relation_ids
        shapley_values = {node_id: 0.0 for node_id in all_ids}
        if self.n_shapley_samples <= 0:
            # No permutations sampled; also avoids dividing by zero below.
            return shapley_values

        for _ in range(self.n_shapley_samples):
            perm = self._rng.sample(all_ids, len(all_ids))
            coalition: Dict[str, float] = {}
            current_outcome = self.base_outcome
            for node_id in perm:
                # calculate_outcome only reads the dict, so growing it in place
                # is equivalent to copying it each step.
                coalition[node_id] = self._perturbation_value(node_id)
                new_outcome = self.causal_graph.calculate_outcome(coalition)
                shapley_values[node_id] += new_outcome - current_outcome
                current_outcome = new_outcome

        for node_id in shapley_values:
            shapley_values[node_id] /= self.n_shapley_samples
        return shapley_values

    def analyze(self) -> Tuple[Dict[str, float], Dict[str, float]]:
        """
        Perform complete causal analysis.

        Returns:
            Tuple of (ACE scores, Shapley values)
        """
        ace_scores = self.calculate_ace()
        shapley_values = self.calculate_shapley_values()
        return ace_scores, shapley_values
def enrich_knowledge_graph(kg: Dict, ace_scores: Dict[str, float],
                           shapley_values: Dict[str, float]) -> Dict:
    """
    Return a deep copy of *kg* annotated with causal attribution scores.

    Every item in the copy's ``"entities"`` list, then its ``"relations"``
    list, gains a ``"causal_attribution"`` dict holding that item's
    ``"ACE"`` and ``"Shapley"`` scores. IDs missing from either mapping get 0.
    The input graph is left unmodified.

    Args:
        kg: Original knowledge graph
        ace_scores: Dictionary of ACE scores
        shapley_values: Dictionary of Shapley values

    Returns:
        Enriched knowledge graph
    """
    enriched = copy.deepcopy(kg)
    for section in ("entities", "relations"):
        for item in enriched[section]:
            item_id = item["id"]
            item["causal_attribution"] = {
                "ACE": ace_scores.get(item_id, 0),
                "Shapley": shapley_values.get(item_id, 0),
            }
    return enriched
def generate_summary_report(ace_scores: Dict[str, float],
                            shapley_values: Dict[str, float],
                            kg: Dict) -> List[Dict]:
    """
    Generate a summary report of causal attributions.

    One row is produced per ID in ``ace_scores``. IDs listed in
    ``kg["entities"]`` are typed ``"entity"``, and every other ID is typed
    ``"relation"``. IDs missing from ``shapley_values`` get a Shapley value of 0.
    Rows are ordered by descending absolute Shapley value; the sort is stable,
    so ties keep ``ace_scores`` order.

    Args:
        ace_scores: Dictionary of ACE scores
        shapley_values: Dictionary of Shapley values
        kg: Knowledge graph

    Returns:
        List of attribution data for each entity/relation
    """
    # A set gives O(1) membership tests instead of scanning a list per ID.
    entity_ids = {entity["id"] for entity in kg["entities"]}
    report = [
        {
            "id": node_id,
            "ACE": ace,
            "Shapley": shapley_values.get(node_id, 0),
            "type": "entity" if node_id in entity_ids else "relation",
        }
        for node_id, ace in ace_scores.items()
    ]
    # Sort by Shapley value to highlight most important factors
    report.sort(key=lambda row: abs(row["Shapley"]), reverse=True)
    return report