Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Causal Graph Analysis | |
| This module implements the core causal graph and analysis logic for the multi-agent system. | |
| It handles perturbation propagation and effect calculation. | |
| """ | |
| from collections import defaultdict | |
| import random | |
| import json | |
| import copy | |
| import numpy as np | |
| import os | |
| from typing import Dict, Set, List, Tuple, Any, Optional, Union | |
class CausalGraph:
    """
    Causal view of the multi-agent system derived from a knowledge graph.

    The knowledge graph is a dict with "entities" and "relations" lists whose
    items each carry an "id". From it this class builds:

    * ``relation_outcomes``: relation ID -> outcome value (float), for
      relations that provide one.
    * ``relation_dependencies``: entity/relation ID -> set of relation IDs
      that are affected when that ID is perturbed.
    """

    # Keys probed, in priority order, for a relation's outcome value.
    # "purturbation" is a misspelling that exists in stored data, so it is
    # still honoured (and checked first, as before).
    _OUTCOME_KEYS = ("purturbation", "perturbation", "defense_success_rate")

    def __init__(self, knowledge_graph: Dict):
        """
        Index the knowledge graph and build the dependency structure.

        Args:
            knowledge_graph: Dict with "entities" and "relations" lists.

        Raises:
            KeyError: If "entities"/"relations" or an item's "id" is missing.
        """
        self.kg = knowledge_graph
        self.entity_ids = [entity["id"] for entity in self.kg["entities"]]
        self.relation_ids = [relation["id"] for relation in self.kg["relations"]]
        # Extract outcomes and build dependency structure
        self.relation_outcomes = {}
        self.relation_dependencies = defaultdict(set)
        self._build_dependency_graph()

    def _build_dependency_graph(self):
        """Populate ``relation_outcomes`` and ``relation_dependencies`` from the knowledge graph."""
        for relation in self.kg["relations"]:
            rel_id = relation["id"]

            # Take the first outcome key that holds a real value. A key that
            # is present but null must not hide a valid value under a
            # lower-priority key.
            y = next(
                (relation[key] for key in self._OUTCOME_KEYS
                 if relation.get(key) is not None),
                None,
            )
            if y is not None:
                self.relation_outcomes[rel_id] = float(y)

            # Explicit dependencies; tolerate a null "dependencies" value or
            # null inner lists (common in JSON exports) instead of crashing.
            deps = relation.get("dependencies") or {}
            for dep_rel in deps.get("relations") or []:
                self.relation_dependencies[dep_rel].add(rel_id)
            for dep_ent in deps.get("entities") or []:
                self.relation_dependencies[dep_ent].add(rel_id)

            # Self-dependency: a relation can affect its own outcome
            self.relation_dependencies[rel_id].add(rel_id)

            # Source and target entities implicitly affect the relation
            source = relation.get("source", None)
            target = relation.get("target", None)
            if source:
                self.relation_dependencies[source].add(rel_id)
            if target:
                self.relation_dependencies[target].add(rel_id)

    def propagate_effects(self, perturbations: Dict[str, float]) -> Dict[str, float]:
        """
        Propagate perturbation effects through the dependency graph.

        Only direct dependents of the perturbed IDs are considered (no
        transitive closure), and only relations that have a stored outcome
        appear in the result.

        Args:
            perturbations: Dictionary mapping relation/entity IDs to their
                perturbation values.

        Returns:
            Dictionary mapping affected relation IDs to their outcome values:
            the perturbation value if the relation itself is perturbed,
            otherwise its stored outcome.
        """
        # Find all relations affected by the perturbation. ``.get`` avoids
        # inserting empty entries into the defaultdict for unknown IDs.
        affected_relations = set()
        for perturbed_id in perturbations:
            affected_relations.update(self.relation_dependencies.get(perturbed_id, ()))

        outcomes = {}
        for rel_id in affected_relations:
            if rel_id not in self.relation_outcomes:
                # Relations without a stored outcome contribute nothing.
                continue
            # A directly perturbed relation takes the perturbation value;
            # an indirectly affected one keeps its stored outcome.
            outcomes[rel_id] = perturbations.get(rel_id, self.relation_outcomes[rel_id])
        return outcomes

    def calculate_outcome(self, perturbations: Optional[Dict[str, float]] = None) -> float:
        """
        Calculate the final outcome score given a set of perturbations.

        Args:
            perturbations: Dictionary mapping relation/entity IDs to their
                perturbation values; ``None`` is treated as no perturbations.

        Returns:
            The mean of the affected relations' outcome values, or 0.0 when
            no relation with a stored outcome is affected.
        """
        if perturbations is None:
            perturbations = {}
        affected_outcomes = self.propagate_effects(perturbations)
        if not affected_outcomes:
            return 0.0
        # Aggregate outcomes (simple average for now)
        return sum(affected_outcomes.values()) / len(affected_outcomes)
class CausalAnalyzer:
    """
    Performs causal effect analysis on the multi-agent knowledge graph system.

    Calculates Average Causal Effects (ACE) and sampled Shapley values using
    the causal graph's ``entity_ids``, ``relation_ids``, ``relation_outcomes``
    and ``calculate_outcome``.
    """

    def __init__(self, causal_graph: "CausalGraph", n_shapley_samples: int = 200):
        """
        Args:
            causal_graph: Graph whose outcomes are analysed.
            n_shapley_samples: Number of random permutations drawn for the
                Shapley approximation.
        """
        self.causal_graph = causal_graph
        self.n_shapley_samples = n_shapley_samples
        # Outcome with nothing perturbed; the reference point for ACE and
        # the starting value of every Shapley permutation.
        self.base_outcome = self.causal_graph.calculate_outcome({})

    def set_perturbation_score(self, relation_id: str, score: float) -> None:
        """
        Set the perturbation score for a specific relation ID.

        This allows explicitly setting scores from external sources (like
        database queries). The value is written to the causal graph's
        ``relation_outcomes``.

        Args:
            relation_id: The ID of the relation to set the score for
            score: The perturbation score value (typically between 0 and 1)
        """
        self.causal_graph.relation_outcomes[relation_id] = float(score)

    def calculate_ace(self) -> Dict[str, float]:
        """
        Calculate Average Causal Effect (ACE) for each entity and relation.

        Each ID is perturbed on its own and the resulting outcome is compared
        with ``base_outcome``.

        Returns:
            Dictionary mapping IDs to their ACE scores
        """
        graph = self.causal_graph
        ace_scores = {}

        # Relations: perturb with the stored outcome, or with the maximum
        # perturbation (1.0) when no value is available.
        for rel_id in graph.relation_ids:
            pert_value = graph.relation_outcomes.get(rel_id, 1.0)
            ace_scores[rel_id] = graph.calculate_outcome({rel_id: pert_value}) - self.base_outcome

        # Entities: always perturbed with the maximum perturbation (1.0).
        for entity_id in graph.entity_ids:
            ace_scores[entity_id] = graph.calculate_outcome({entity_id: 1.0}) - self.base_outcome

        return ace_scores

    def calculate_shapley_values(self) -> Dict[str, float]:
        """
        Calculate Shapley values to fairly attribute causal effects.

        Approximated by averaging marginal contributions over
        ``n_shapley_samples`` random orderings of all entity and relation IDs.

        Returns:
            Dictionary mapping IDs to their Shapley values. All values are
            0.0 when ``n_shapley_samples`` is not positive.
        """
        graph = self.causal_graph
        # Combine entities and relations as "players" in the Shapley calculation
        all_ids = graph.entity_ids + graph.relation_ids
        shapley_values = {id_: 0.0 for id_ in all_ids}

        # Without samples there is nothing to average; return zeros rather
        # than dividing by zero below.
        if self.n_shapley_samples <= 0:
            return shapley_values

        for _ in range(self.n_shapley_samples):
            perm = random.sample(all_ids, len(all_ids))
            # The coalition grows in place: each step only adds one ID, so a
            # fresh copy per step is unnecessary.
            coalition = {}
            current_outcome = self.base_outcome
            for id_ in perm:
                # Stored outcome if known, otherwise maximum perturbation.
                coalition[id_] = graph.relation_outcomes.get(id_, 1.0)
                new_outcome = graph.calculate_outcome(coalition)
                # Marginal contribution of adding this ID to the coalition
                shapley_values[id_] += new_outcome - current_outcome
                current_outcome = new_outcome

        # Average over the sampled permutations
        for id_ in shapley_values:
            shapley_values[id_] /= self.n_shapley_samples
        return shapley_values

    def analyze(self) -> Tuple[Dict[str, float], Dict[str, float]]:
        """
        Perform complete causal analysis.

        Returns:
            Tuple of (ACE scores, Shapley values)
        """
        ace_scores = self.calculate_ace()
        shapley_values = self.calculate_shapley_values()
        return ace_scores, shapley_values
def enrich_knowledge_graph(kg: Dict, ace_scores: Dict[str, float],
                           shapley_values: Dict[str, float]) -> Dict:
    """
    Return a deep copy of the knowledge graph annotated with causal scores.

    Every entity and relation in the copy receives a "causal_attribution"
    dict holding its "ACE" and "Shapley" values; IDs missing from a score
    dictionary get 0. The input graph is left untouched.

    Args:
        kg: Original knowledge graph
        ace_scores: Dictionary of ACE scores
        shapley_values: Dictionary of Shapley values
    Returns:
        Enriched knowledge graph
    """
    annotated = copy.deepcopy(kg)
    # Entities first, then relations; both are annotated the same way.
    for section in ("entities", "relations"):
        for item in annotated[section]:
            item_id = item["id"]
            attribution = {}
            attribution["ACE"] = ace_scores.get(item_id, 0)
            attribution["Shapley"] = shapley_values.get(item_id, 0)
            item["causal_attribution"] = attribution
    return annotated
def generate_summary_report(ace_scores: Dict[str, float],
                            shapley_values: Dict[str, float],
                            kg: Dict) -> List[Dict]:
    """
    Generate a summary report of causal attributions.

    One row is produced per ID in ``ace_scores``. An ID is labelled
    "entity" when it appears among the knowledge graph's entity IDs and
    "relation" otherwise. IDs without a Shapley value get 0.

    Args:
        ace_scores: Dictionary of ACE scores
        shapley_values: Dictionary of Shapley values
        kg: Knowledge graph
    Returns:
        List of attribution rows ("id", "ACE", "Shapley", "type"), sorted by
        descending absolute Shapley value; ties keep ``ace_scores`` order.
    """
    # Set membership keeps the per-ID type lookup O(1) instead of scanning
    # the entity list for every score.
    entity_ids = {entity["id"] for entity in kg["entities"]}

    report = [
        {
            "id": id_,
            "ACE": ace,
            "Shapley": shapley_values.get(id_, 0),
            "type": "entity" if id_ in entity_ids else "relation",
        }
        for id_, ace in ace_scores.items()
    ]

    # Sort by Shapley value to highlight most important factors
    report.sort(key=lambda row: abs(row["Shapley"]), reverse=True)
    return report