# AgentGraph — agentgraph/causal/causal_interface.py
# Deploy AgentGraph: Complete agent monitoring and knowledge graph system (commit c2ea5ed)
from collections import defaultdict
import random
import json
import copy
import numpy as np
import os
from typing import Dict, Set, List, Tuple, Any, Optional, Union
from datetime import datetime
from tqdm import tqdm
import logging
import pandas as pd
# Configure logging for this module
# Module-level logger; all functions below log through this.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig at import time configures the ROOT logger as a
# module-import side effect; applications usually own this call — confirm intended.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Import all causal analysis methods
from .graph_analysis import (
CausalGraph,
CausalAnalyzer as GraphAnalyzer,
enrich_knowledge_graph as enrich_graph,
generate_summary_report
)
from .influence_analysis import (
analyze_component_influence,
print_feature_importance,
evaluate_model,
identify_key_components,
print_component_groups
)
from .dowhy_analysis import (
analyze_components_with_dowhy,
run_dowhy_analysis
)
from .confounders.basic_detection import (
detect_confounders,
analyze_confounder_impact,
run_confounder_analysis
)
from .confounders.multi_signal_detection import (
run_mscd_analysis
)
from .component_analysis import (
calculate_average_treatment_effect,
granger_causality_test,
compute_causal_effect_strength
)
from .utils.dataframe_builder import create_component_influence_dataframe
def analyze_causal_effects(analysis_data: Dict[str, Any], methods: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Run the requested causal analysis methods over pre-filtered data.

    Args:
        analysis_data: Dictionary containing all data needed for analysis.
        methods: Subset of ('graph', 'component', 'dowhy', 'confounder',
            'mscd', 'ate') to execute; all of them when None.

    Returns:
        Mapping from method name to that method's result dictionary; a
        method that failed maps to an {"error": ...} payload. If
        analysis_data itself carries an "error" key it is returned as-is.
    """
    available_methods = ['graph', 'component', 'dowhy', 'confounder', 'mscd', 'ate']
    # Upstream preparation already failed — propagate the error payload unchanged.
    if "error" in analysis_data:
        return analysis_data

    # Method name -> handler; replaces the original if/elif chain.
    dispatch = {
        'graph': _analyze_graph,
        'component': _analyze_component,
        'dowhy': _analyze_dowhy,
        'confounder': _analyze_confounder,
        'mscd': _analyze_mscd,
        'ate': _analyze_component_ate,
    }

    selected = available_methods if methods is None else methods
    results: Dict[str, Any] = {}
    for method in tqdm(selected, desc="Running causal analysis"):
        handler = dispatch.get(method)
        if handler is None:
            logger.warning(f"Unknown analysis method specified: {method}")
            continue
        try:
            outcome = handler(analysis_data)
            results[method] = outcome
            # Surface errors the method reported in-band rather than raised.
            if outcome and isinstance(outcome, dict) and "error" in outcome:
                logger.error(f"Error explicitly returned by {method} analysis: {outcome['error']}")
        except Exception as e:
            logger.error(f"Exception caught during {method} analysis: {repr(e)}")
            results[method] = {"error": repr(e)}
    return results
def _create_component_dataframe(analysis_data: Dict) -> pd.DataFrame:
"""
Create a DataFrame for component analysis from the pre-filtered data.
Args:
analysis_data: Pre-filtered analysis data containing perturbation tests and dependencies
Returns:
DataFrame with component features and perturbation scores
"""
perturbation_tests = analysis_data["perturbation_tests"]
dependencies_map = analysis_data["dependencies_map"]
# Build a matrix of features (from dependencies) and perturbation scores
rows = []
# Track all unique entity and relation IDs
all_entity_ids = set()
all_relation_ids = set()
# First pass: identify all unique entities and relations across all dependencies
for test in perturbation_tests:
pr_id = test["prompt_reconstruction_id"]
dependencies = dependencies_map.get(pr_id, {})
# Skip if dependencies not found or not a dictionary
if not dependencies or not isinstance(dependencies, dict):
continue
# Extract entity and relation dependencies
entity_deps = dependencies.get("entities", [])
relation_deps = dependencies.get("relations", [])
# Add to our tracking sets
if isinstance(entity_deps, list):
all_entity_ids.update(entity_deps)
if isinstance(relation_deps, list):
all_relation_ids.update(relation_deps)
# Second pass: create rows with binary features
for test in perturbation_tests:
pr_id = test["prompt_reconstruction_id"]
dependencies = dependencies_map.get(pr_id, {})
# Skip if dependencies not found or not a dictionary
if not dependencies or not isinstance(dependencies, dict):
continue
# Extract entity and relation dependencies
entity_deps = dependencies.get("entities", [])
relation_deps = dependencies.get("relations", [])
# Ensure they are lists
if not isinstance(entity_deps, list):
entity_deps = []
if not isinstance(relation_deps, list):
relation_deps = []
# Create row with perturbation score
row = {"perturbation": test["perturbation_score"]}
# Add binary features for entities
for entity_id in all_entity_ids:
row[f"entity_{entity_id}"] = 1 if entity_id in entity_deps else 0
# Add binary features for relations
for relation_id in all_relation_ids:
row[f"relation_{relation_id}"] = 1 if relation_id in relation_deps else 0
rows.append(row)
# Create the DataFrame
df = pd.DataFrame(rows)
# If no rows with features were created, return an empty DataFrame
if df.empty:
logger.warning("No rows with features could be created from the dependencies")
return pd.DataFrame()
return df
def _analyze_graph(analysis_data: Dict) -> Dict[str, Any]:
    """
    Graph-based causal analysis over the pre-filtered knowledge graph.

    Keeps only the relations that have a perturbation score, seeds the
    analyzer with those scores, and returns ACE and Shapley attributions.

    Args:
        analysis_data: Pre-filtered analysis data containing "knowledge_graph"
            and "perturbation_scores" (relation id -> score).
    """
    kg_data = analysis_data["knowledge_graph"]
    perturbation_scores = analysis_data["perturbation_scores"]

    # Deep-copy so the caller's knowledge graph is left untouched, then
    # drop every relation without a perturbation score.
    filtered_kg = copy.deepcopy(kg_data)
    scored_relations = [
        relation for relation in filtered_kg.get("relations", [])
        if relation.get("id") in perturbation_scores
    ]
    filtered_kg["relations"] = scored_relations

    analyzer = GraphAnalyzer(CausalGraph(filtered_kg))
    for relation_id, score in perturbation_scores.items():
        analyzer.set_perturbation_score(relation_id, score)

    ace_scores, shapley_values = analyzer.analyze()
    return {
        "scores": {
            "ACE": ace_scores,
            "Shapley": shapley_values
        },
        "metadata": {
            "method": "graph",
            "relations_analyzed": len(scored_relations)
        }
    }
def _analyze_component(analysis_data: Dict) -> Dict[str, Any]:
    """
    Component-level causal analysis via a fitted influence model.

    Builds the binary component feature matrix, fits the influence model,
    evaluates it, and flags components whose absolute importance is >= 0.01.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation
            tests and dependencies.
    """
    def _failure(message: str) -> Dict[str, Any]:
        # Uniform error payload shared by both validation checks.
        logger.error(message)
        return {"error": message, "scores": {}, "metadata": {"method": "component"}}

    df = _create_component_dataframe(analysis_data)
    if df is None or df.empty:
        return _failure("Failed to create or empty DataFrame for component analysis")
    if 'perturbation' not in df.columns:
        return _failure("'perturbation' column missing from DataFrame.")

    # Fit the influence model; it also reports which feature columns were used.
    rf_model, feature_importance, feature_cols = analyze_component_influence(df)

    if feature_cols:
        metrics = evaluate_model(rf_model, df[feature_cols], df['perturbation'])
    else:
        # No usable features (e.g. no variance): synthesize trivial metrics.
        metrics = {'mse': 0.0, 'rmse': 0.0, 'r2': 1.0 if df['perturbation'].std() == 0 else 0.0}

    # A component is "key" when its absolute importance clears the threshold.
    key_components = [
        name for name, weight in feature_importance.items() if abs(weight) >= 0.01
    ]

    return {
        "scores": {
            "Feature_Importance": feature_importance,
            "Model_Metrics": metrics,
            "Key_Components": key_components
        },
        "metadata": {
            "method": "component",
            "model_type": "LinearModel",
            "rows_analyzed": len(df)
        }
    }
def _detect_cooccurrence_confounders(df: pd.DataFrame, components: List[str],
                                     threshold: float = 1.5) -> Dict[str, List[Dict[str, Any]]]:
    """
    Flag component pairs that co-occur more often than independence predicts.

    For every pair (comp1, comp2 later in the list), the observed count of
    rows where both indicators are 1 is compared with the expected count
    under independence; ratios above ``threshold`` are recorded.

    Args:
        df: Binary component feature matrix.
        components: Column names to pairwise-compare.
        threshold: Co-occurrence ratio above which a pair is flagged.

    Returns:
        Mapping comp1 -> list of flagged partner records.
    """
    confounders: Dict[str, List[Dict[str, Any]]] = {}
    n_rows = len(df)
    for i, comp1 in enumerate(components):
        # Hoisted out of the inner loop: it only depends on comp1.
        comp1_present = (df[comp1] == 1).sum()
        for comp2 in components[i + 1:]:
            both_present = ((df[comp1] == 1) & (df[comp2] == 1)).sum()
            comp2_present = (df[comp2] == 1).sum()
            if comp1_present > 0 and comp2_present > 0:
                # Expected co-occurrence count assuming independence.
                expected = (comp1_present * comp2_present) / n_rows
                if expected > 0:
                    co_occurrence_ratio = both_present / expected
                    if co_occurrence_ratio > threshold:
                        confounders.setdefault(comp1, []).append({
                            "confounder": comp2,
                            "co_occurrence_ratio": co_occurrence_ratio,
                            "both_present": both_present,
                            "expected": expected
                        })
    return confounders


def _collect_interaction_effects(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Merge per-component interaction data from DoWhy results into one map.

    Both the 'interacts_with' list and any directly detected
    'interaction_effects' entries are folded into the same per-component list.
    """
    interaction_effects: Dict[str, Any] = {}
    for result in results:
        component = result.get('component')
        if not component:
            continue
        if 'interacts_with' in result:
            interaction_effects[component] = result['interacts_with']
        # Also fold in directly detected interaction effects.
        if 'interaction_effects' in result:
            entries = interaction_effects.setdefault(component, [])
            for interaction in result['interaction_effects']:
                entries.append({
                    'component': interaction['component'],
                    'interaction_coefficient': interaction['interaction_coefficient']
                })
    return interaction_effects


def _analyze_dowhy(analysis_data: Dict) -> Dict[str, Any]:
    """
    Perform DoWhy-based causal analysis using pre-filtered data.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation
            tests and dependencies.

    Returns:
        Scores (effect estimates, refutations, interactions, co-occurrence
        confounders) plus metadata, or an {"error": ...} payload.
    """
    # Reuse the same feature matrix as the component analysis.
    df = _create_component_dataframe(analysis_data)
    if df is None or df.empty:
        return {
            "error": "Failed to create DataFrame for DoWhy analysis",
            "scores": {},
            "metadata": {"method": "dowhy"}
        }

    components = [col for col in df.columns if col.startswith(('entity_', 'relation_'))]
    if not components:
        return {
            "error": "No component features found for DoWhy analysis",
            "scores": {},
            "metadata": {"method": "dowhy"}
        }

    # A confounder may be present if two variables appear together more
    # frequently than would be expected by chance.
    confounders = _detect_cooccurrence_confounders(df, components)

    logger.info(f"Running DoWhy analysis with all {len(components)} components")
    results = analyze_components_with_dowhy(df, components)

    effect_estimates = {r['component']: r.get('effect_estimate', 0) for r in results}
    refutation_results = {r['component']: r.get('refutation_results', []) for r in results}
    interaction_effects = _collect_interaction_effects(results)

    return {
        "scores": {
            "Effect_Estimate": effect_estimates,
            "Refutation_Results": refutation_results,
            "Interaction_Effects": interaction_effects,
            "Confounders": confounders
        },
        "metadata": {
            "method": "dowhy",
            "analysis_type": "backdoor.linear_regression",
            "rows_analyzed": len(df),
            "components_analyzed": len(components)
        }
    }
def _analyze_confounder(analysis_data: Dict) -> Dict[str, Any]:
    """
    Basic confounder detection over the component feature matrix.

    Builds the binary component matrix and delegates to
    run_confounder_analysis, always checking a fixed set of known pairs.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation
            tests and dependencies.
    """
    feature_df = _create_component_dataframe(analysis_data)
    if feature_df is None or feature_df.empty:
        return {
            "error": "Failed to create DataFrame for confounder analysis",
            "scores": {},
            "metadata": {"method": "confounder"}
        }

    feature_cols = [name for name in feature_df.columns if name.startswith(('entity_', 'relation_'))]
    if not feature_cols:
        return {
            "error": "No component features found for confounder analysis",
            "scores": {},
            "metadata": {"method": "confounder"}
        }

    # Known pairs from the test data that should always be examined.
    known_pairs = [
        ("relation_relation-9", "relation_relation-10"),
        ("entity_input-001", "entity_human-user-001")
    ]

    logger.info(f"Running confounder detection analysis with {len(feature_cols)} components")
    detection = run_confounder_analysis(
        feature_df,
        outcome_var="perturbation",
        cooccurrence_threshold=1.2,
        min_occurrences=2,
        specific_confounder_pairs=known_pairs
    )

    return {
        "scores": {
            "Confounders": detection.get("confounders", {}),
            "Impact_Analysis": detection.get("impact_analysis", {}),
            "Summary": detection.get("summary", {})
        },
        "metadata": {
            "method": "confounder",
            "rows_analyzed": len(feature_df),
            "components_analyzed": len(feature_cols)
        }
    }
def _analyze_mscd(analysis_data: Dict) -> Dict[str, Any]:
    """
    Multi-Signal Confounder Detection (MSCD) over the component features.

    Builds the binary component matrix and delegates to run_mscd_analysis
    with a fixed set of known confounder pairs to verify.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation
            tests and dependencies.
    """
    feature_df = _create_component_dataframe(analysis_data)
    if feature_df is None or feature_df.empty:
        return {
            "error": "Failed to create DataFrame for MSCD analysis",
            "scores": {},
            "metadata": {"method": "mscd"}
        }

    feature_cols = [name for name in feature_df.columns if name.startswith(('entity_', 'relation_'))]
    if not feature_cols:
        return {
            "error": "No component features found for MSCD analysis",
            "scores": {},
            "metadata": {"method": "mscd"}
        }

    # Pairs from the test data we always want checked explicitly.
    known_pairs = [
        ("relation_relation-9", "relation_relation-10"),
        ("entity_input-001", "entity_human-user-001")
    ]

    logger.info(f"Running Multi-Signal Confounder Detection with {len(feature_cols)} components")
    mscd_output = run_mscd_analysis(
        feature_df,
        outcome_var="perturbation",
        specific_confounder_pairs=known_pairs
    )

    return {
        "scores": {
            "Confounders": mscd_output.get("combined_confounders", {}),
            "Method_Results": mscd_output.get("method_results", {}),
            "Summary": mscd_output.get("summary", {})
        },
        "metadata": {
            "method": "mscd",
            "rows_analyzed": len(feature_df),
            "components_analyzed": len(feature_cols)
        }
    }
def _analyze_component_ate(analysis_data: Dict) -> Dict[str, Any]:
    """
    Component Average Treatment Effect (ATE) analysis on pre-filtered data.

    Computes per-component causal effect strengths, runs Granger causality
    tests on the ten strongest components, and computes the ATE for every
    component column. Per-component failures degrade to neutral placeholder
    values rather than aborting the whole analysis.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation
            tests and dependencies.
    """
    try:
        logger.info("Starting Component ATE analysis")
        df = _create_component_dataframe(analysis_data)
        if df is None or df.empty:
            logger.error("Failed to create component DataFrame for ATE analysis")
            return {"error": "Failed to create component DataFrame"}

        component_cols = [c for c in df.columns if c.startswith(("entity_", "relation_"))]
        if not component_cols:
            logger.error("No component features found in DataFrame for ATE analysis")
            return {"error": "No component features found"}

        # 1) Effect strengths, ranked by magnitude (strongest first).
        logger.info("Computing causal effect strengths (ATE)")
        effect_strengths = compute_causal_effect_strength(df)
        ranked = sorted(effect_strengths.items(), key=lambda item: abs(item[1]), reverse=True)

        # 2) Granger causality on (at most) the ten strongest components.
        logger.info("Running Granger causality tests on top components")
        top_components = [name for name, _ in ranked[:10]]
        granger_results = {}
        for name in top_components:
            try:
                granger_results[name] = granger_causality_test(df, name)
            except Exception as e:
                logger.warning(f"Error in Granger causality test for {name}: {e}")
                # Neutral placeholder: no detectable causal signal.
                granger_results[name] = {
                    "f_statistic": 0.0,
                    "p_value": 1.0,
                    "causal_direction": "error"
                }

        # 3) Average treatment effect for every component column.
        logger.info("Computing ATE for all components")
        ate_results = {}
        for name in component_cols:
            try:
                ate_results[name] = calculate_average_treatment_effect(df, name)
            except Exception as e:
                logger.warning(f"Error computing ATE for {name}: {e}")
                # Neutral placeholder with an insignificant p-value.
                ate_results[name] = {
                    "ate": 0.0,
                    "std_error": 0.0,
                    "t_statistic": 0.0,
                    "p_value": 1.0
                }

        return {
            "scores": {
                "Effect_Strengths": effect_strengths,
                "Granger_Results": granger_results,
                "ATE_Results": ate_results
            },
            "metadata": {
                "method": "ate",
                "components_analyzed": len(component_cols),
                "top_components_tested": len(top_components),
                "rows_analyzed": len(df)
            }
        }
    except Exception as e:
        logger.error(f"Error in Component ATE analysis: {str(e)}")
        return {"error": f"Component ATE analysis failed: {str(e)}"}
def _causal_attribution_for(item_id: str, results: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the per-method causal attribution mapping for one entity/relation id.

    Methods that reported an error contribute nothing; unknown methods are
    ignored. Missing scores default to 0 (or [] for refutation results).
    """
    attribution: Dict[str, Any] = {}
    for method, result in results.items():
        if "error" in result:
            continue  # failed methods contribute no attribution
        scores = result["scores"]
        if method == "graph":
            attribution["graph"] = {
                "ACE": scores["ACE"].get(item_id, 0),
                "Shapley": scores["Shapley"].get(item_id, 0)
            }
        elif method == "component":
            attribution["component"] = {
                "Feature_Importance": scores["Feature_Importance"].get(item_id, 0),
                "Is_Key_Component": item_id in scores["Key_Components"]
            }
        elif method == "dowhy":
            attribution["dowhy"] = {
                "Effect_Estimate": scores["Effect_Estimate"].get(item_id, 0),
                "Refutation_Results": scores["Refutation_Results"].get(item_id, [])
            }
    return attribution


def enrich_knowledge_graph(kg_data: Dict, results: Dict[str, Any]) -> Dict:
    """
    Enrich knowledge graph with causal attribution scores from all methods.

    Entities and relations receive identical treatment, so both lists share
    the _causal_attribution_for helper (the original duplicated the logic).

    Args:
        kg_data: Original knowledge graph data; must contain "entities" and
            "relations" lists of dicts each carrying an "id" key.
        results: Analysis results keyed by method name.

    Returns:
        Deep copy of the knowledge graph with a "causal_attribution" dict
        added to every entity and relation; kg_data itself is not mutated.

    Raises:
        ValueError: If results is empty.
    """
    if not results:
        raise ValueError("No analysis results available")

    enriched_kg = copy.deepcopy(kg_data)
    # Annotate entities and relations with identical logic.
    for item in enriched_kg["entities"] + enriched_kg["relations"]:
        item["causal_attribution"] = _causal_attribution_for(item["id"], results)
    return enriched_kg
def generate_report(kg_data: Dict, results: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build a comprehensive report summarising causal analysis results.

    Args:
        kg_data: Original knowledge graph data.
        results: Analysis results from all methods.

    Returns:
        Report dictionary with summary, per-method results, key findings,
        and recommendations; an {"error": ...} payload when results is empty.
    """
    if not results:
        return {"error": "No analysis results available for report generation"}

    succeeded = [m for m, r in results.items() if "error" not in r]
    failed = [m for m, r in results.items() if "error" in r]

    # Per-method status entries.
    method_results = {}
    for method, result in results.items():
        if "error" in result:
            method_results[method] = {"status": "failed", "error": result["error"]}
        else:
            method_results[method] = {
                "status": "success",
                "scores": result.get("scores", {}),
                "metadata": result.get("metadata", {})
            }

    # Headline findings from the graph and component analyses, when available.
    key_findings = []
    if "graph" in results and "error" not in results["graph"]:
        ace_scores = results["graph"]["scores"].get("ACE", {})
        if ace_scores:
            top_ace = max(ace_scores.items(), key=lambda pair: abs(pair[1]))
            key_findings.append(f"Strongest causal effect detected on {top_ace[0]} (ACE: {top_ace[1]:.3f})")
    if "component" in results and "error" not in results["component"]:
        key_components = results["component"]["scores"].get("Key_Components", [])
        if key_components:
            key_findings.append(f"Key causal components identified: {', '.join(key_components[:5])}")

    recommendations = []
    if failed:
        recommendations.append("Consider investigating failed analysis methods for data quality issues")
    total_relations = len(kg_data.get("relations", []))
    if total_relations < 10:
        recommendations.append("Small knowledge graph may limit causal analysis accuracy")

    return {
        "summary": {
            "total_entities": len(kg_data.get("entities", [])),
            "total_relations": total_relations,
            "methods_used": list(results.keys()),
            "successful_methods": succeeded,
            "failed_methods": failed
        },
        "method_results": method_results,
        "key_findings": key_findings,
        "recommendations": recommendations
    }