#!/usr/bin/env python3
"""
Causal Graph Analysis

This module implements the core causal graph and analysis logic for the
multi-agent system. It handles perturbation propagation and effect calculation.
"""

from collections import defaultdict
import random
import json
import copy
import numpy as np
import os
from typing import Dict, Set, List, Tuple, Any, Optional, Union

# Relation keys checked, in priority order, for a stored perturbation outcome.
# 'purturbation' is a legacy misspelling still honoured for backward
# compatibility; the first key whose value is not None wins.
_OUTCOME_KEYS = ("purturbation", "perturbation", "defense_success_rate")


class CausalGraph:
    """
    Represents the causal graph of the multi-agent system derived from the
    knowledge graph. Handles perturbation propagation and effect calculation.

    Attributes:
        kg: The knowledge graph; must contain "entities" and "relations"
            lists whose items each carry an "id".
        entity_ids: Entity IDs in knowledge-graph order.
        relation_ids: Relation IDs in knowledge-graph order.
        relation_outcomes: Relation ID -> stored outcome value (float).
        relation_dependencies: Entity/relation ID -> set of relation IDs whose
            outcome is affected when that ID is perturbed.
    """

    def __init__(self, knowledge_graph: Dict):
        self.kg = knowledge_graph
        self.entity_ids = [entity["id"] for entity in self.kg["entities"]]
        self.relation_ids = [relation["id"] for relation in self.kg["relations"]]

        # Extract outcomes and build dependency structure
        self.relation_outcomes: Dict[str, float] = {}
        self.relation_dependencies: Dict[str, Set[str]] = defaultdict(set)
        self._build_dependency_graph()

    @staticmethod
    def _stored_outcome(relation: Dict) -> Optional[float]:
        """Return the first non-None value under _OUTCOME_KEYS as a float, else None."""
        for key in _OUTCOME_KEYS:
            value = relation.get(key)
            if value is not None:
                return float(value)
        return None

    def _build_dependency_graph(self):
        """Build the perturbation dependency graph based on the knowledge graph structure.

        For every relation this records its stored outcome (if any) and marks
        the relation as affected by: each ID listed under
        dependencies.relations / dependencies.entities, the relation itself,
        and its source and target entities (when set).
        """
        for relation in self.kg["relations"]:
            rel_id = relation["id"]

            outcome = self._stored_outcome(relation)
            if outcome is not None:
                self.relation_outcomes[rel_id] = outcome

            # "dependencies" and its sub-lists may be absent or null.
            deps = relation.get("dependencies") or {}
            explicit = list(deps.get("relations") or []) + list(deps.get("entities") or [])
            for dep_id in explicit:
                self.relation_dependencies[dep_id].add(rel_id)

            # Self-dependency: a relation can affect its own outcome
            self.relation_dependencies[rel_id].add(rel_id)

            # Source and target entities automatically affect the relation
            for endpoint in (relation.get("source"), relation.get("target")):
                if endpoint:
                    self.relation_dependencies[endpoint].add(rel_id)

    def propagate_effects(self, perturbations: Dict[str, float]) -> Dict[str, float]:
        """
        Propagate perturbation effects through the dependency graph.

        Args:
            perturbations: Dictionary mapping relation/entity IDs to their perturbation values (0-1)

        Returns:
            Dictionary mapping affected relation IDs to their outcome values.
            Only relations with a stored outcome are included; a relation that
            is itself perturbed reports its perturbation value instead.
        """
        affected_relations: Set[str] = set()
        for perturbed_id in perturbations:
            # .get() avoids creating empty entries in the defaultdict
            affected_relations.update(self.relation_dependencies.get(perturbed_id, ()))

        outcomes = {}
        for rel_id in affected_relations:
            if rel_id not in self.relation_outcomes:
                continue  # relations without a stored outcome are not scored
            if rel_id in perturbations:
                outcomes[rel_id] = perturbations[rel_id]
            else:
                outcomes[rel_id] = self.relation_outcomes[rel_id]
        return outcomes

    def calculate_outcome(self, perturbations: Optional[Dict[str, float]] = None) -> float:
        """
        Calculate the final outcome score given a set of perturbations.

        Args:
            perturbations: Dictionary mapping relation/entity IDs to their perturbation values (0-1)

        Returns:
            Mean of the affected relations' outcome values, or 0.0 when no
            relation is affected.
        """
        if perturbations is None:
            perturbations = {}

        affected_outcomes = self.propagate_effects(perturbations)
        if not affected_outcomes:
            return 0.0

        # Aggregate outcomes (simple average for now)
        return sum(affected_outcomes.values()) / len(affected_outcomes)


class CausalAnalyzer:
    """
    Performs causal effect analysis on the multi-agent knowledge graph system.
    Calculates Average Causal Effects (ACE) and Shapley values.
    """

    def __init__(self, causal_graph: CausalGraph, n_shapley_samples: int = 200,
                 rng: Optional[random.Random] = None):
        """
        Args:
            causal_graph: Graph whose outcomes are analysed.
            n_shapley_samples: Number of random permutations used to estimate
                Shapley values.
            rng: Optional random generator used for permutation sampling, for
                reproducible estimates. When None the module-level ``random``
                functions are used.
        """
        self.causal_graph = causal_graph
        self.n_shapley_samples = n_shapley_samples
        self.rng = rng
        self.base_outcome = self.causal_graph.calculate_outcome({})

    def set_perturbation_score(self, relation_id: str, score: float) -> None:
        """
        Set the perturbation score for a specific relation ID.
        This allows explicitly setting scores from external sources (like database queries).

        Args:
            relation_id: The ID of the relation to set the score for
            score: The perturbation score value (typically between 0 and 1)
        """
        # Update the relation_outcomes in the causal graph
        self.causal_graph.relation_outcomes[relation_id] = float(score)

    def calculate_ace(self) -> Dict[str, float]:
        """
        Calculate Average Causal Effect (ACE) for each entity and relation.

        Each relation is perturbed with its stored outcome value (1.0 if it has
        none); each entity is perturbed with 1.0. The ACE is the resulting
        outcome minus the unperturbed base outcome.

        Returns:
            Dictionary mapping IDs to their ACE scores
        """
        graph = self.causal_graph
        ace_scores = {}

        for rel_id in graph.relation_ids:
            pert_value = graph.relation_outcomes.get(rel_id, 1.0)
            ace_scores[rel_id] = graph.calculate_outcome({rel_id: pert_value}) - self.base_outcome

        for entity_id in graph.entity_ids:
            ace_scores[entity_id] = graph.calculate_outcome({entity_id: 1.0}) - self.base_outcome

        return ace_scores

    def calculate_shapley_values(self) -> Dict[str, float]:
        """
        Calculate Shapley values to fairly attribute causal effects.
        Uses permutation sampling for approximation with larger graphs.

        Returns:
            Dictionary mapping IDs to their Shapley values. When
            n_shapley_samples is not positive, every value is 0.0.
        """
        graph = self.causal_graph
        # Combine entities and relations as "players" in the Shapley calculation
        all_ids = graph.entity_ids + graph.relation_ids
        shapley_values = {id_: 0.0 for id_ in all_ids}

        if self.n_shapley_samples <= 0:
            return shapley_values  # no samples -> no estimate (avoids division by zero)

        # Perturbation value per player: stored outcome, else maximum (1.0).
        pert_values = {id_: graph.relation_outcomes.get(id_, 1.0) for id_ in all_ids}
        sampler = self.rng if self.rng is not None else random

        for _ in range(self.n_shapley_samples):
            perm = sampler.sample(all_ids, len(all_ids))
            # The coalition only grows within a permutation, so one dict is
            # extended in place instead of being copied at every step.
            coalition: Dict[str, float] = {}
            current_outcome = self.base_outcome
            for id_ in perm:
                coalition[id_] = pert_values[id_]
                new_outcome = graph.calculate_outcome(coalition)
                shapley_values[id_] += new_outcome - current_outcome
                current_outcome = new_outcome

        # Average marginal contributions over the sampled permutations
        for id_ in shapley_values:
            shapley_values[id_] /= self.n_shapley_samples

        return shapley_values

    def analyze(self) -> Tuple[Dict[str, float], Dict[str, float]]:
        """
        Perform complete causal analysis.

        Returns:
            Tuple of (ACE scores, Shapley values)
        """
        ace_scores = self.calculate_ace()
        shapley_values = self.calculate_shapley_values()
        return ace_scores, shapley_values


def enrich_knowledge_graph(kg: Dict, ace_scores: Dict[str, float],
                           shapley_values: Dict[str, float]) -> Dict:
    """
    Enrich the knowledge graph with causal attribution scores.

    Args:
        kg: Original knowledge graph (not modified; a deep copy is returned)
        ace_scores: Dictionary of ACE scores
        shapley_values: Dictionary of Shapley values

    Returns:
        Enriched knowledge graph in which every entity and relation carries a
        "causal_attribution" dict with "ACE" and "Shapley" (0 when missing).
    """
    enriched_kg = copy.deepcopy(kg)

    for section in ("entities", "relations"):
        for item in enriched_kg[section]:
            item_id = item["id"]
            item["causal_attribution"] = {
                "ACE": ace_scores.get(item_id, 0),
                "Shapley": shapley_values.get(item_id, 0)
            }

    return enriched_kg


def generate_summary_report(ace_scores: Dict[str, float],
                            shapley_values: Dict[str, float],
                            kg: Dict) -> List[Dict]:
    """
    Generate a summary report of causal attributions.

    Args:
        ace_scores: Dictionary of ACE scores (defines which IDs are reported)
        shapley_values: Dictionary of Shapley values
        kg: Knowledge graph, used to tell entities from relations

    Returns:
        List of attribution rows, sorted by absolute Shapley value (largest first)
    """
    # Set for O(1) membership tests inside the loop below.
    entity_ids = {entity["id"] for entity in kg["entities"]}

    report = [
        {
            "id": id_,
            "ACE": ace_scores.get(id_, 0),
            "Shapley": shapley_values.get(id_, 0),
            "type": "entity" if id_ in entity_ids else "relation"
        }
        for id_ in ace_scores
    ]

    # Sort by Shapley value to highlight most important factors
    report.sort(key=lambda x: abs(x["Shapley"]), reverse=True)
    return report