from collections import defaultdict
import random
import json
import copy
import numpy as np
import os
from typing import Dict, Set, List, Tuple, Any, Optional, Union
from datetime import datetime
from tqdm import tqdm
import logging
import pandas as pd

# Configure logging for this module
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Import all causal analysis methods
from .graph_analysis import (
    CausalGraph,
    CausalAnalyzer as GraphAnalyzer,
    enrich_knowledge_graph as enrich_graph,
    generate_summary_report,
)
from .influence_analysis import (
    analyze_component_influence,
    print_feature_importance,
    evaluate_model,
    identify_key_components,
    print_component_groups,
)
from .dowhy_analysis import (
    analyze_components_with_dowhy,
    run_dowhy_analysis,
)
from .confounders.basic_detection import (
    detect_confounders,
    analyze_confounder_impact,
    run_confounder_analysis,
)
from .confounders.multi_signal_detection import (
    run_mscd_analysis,
)
from .component_analysis import (
    calculate_average_treatment_effect,
    granger_causality_test,
    compute_causal_effect_strength,
)
from .utils.dataframe_builder import create_component_influence_dataframe

# Analysis methods run by analyze_causal_effects when no explicit list is given,
# in execution order.
_AVAILABLE_METHODS = ('graph', 'component', 'dowhy', 'confounder', 'mscd', 'ate')

# Column-name prefixes that _create_component_dataframe gives to the binary
# entity / relation dependency features.
_ENTITY_PREFIX = "entity_"
_RELATION_PREFIX = "relation_"
_COMPONENT_PREFIXES = (_ENTITY_PREFIX, _RELATION_PREFIX)

# Component pairs that both confounder detectors are always asked to check.
# NOTE(review): these IDs look specific to one test dataset -- confirm whether
# they should be supplied by the caller instead.
_SPECIFIC_CONFOUNDER_PAIRS = (
    ("relation_relation-9", "relation_relation-10"),
    ("entity_input-001", "entity_human-user-001"),
)

# Minimum absolute feature importance for a feature to count as a key component.
_KEY_COMPONENT_THRESHOLD = 0.01

# How many of the strongest-effect components get a Granger causality test.
_GRANGER_TOP_K = 10


def analyze_causal_effects(analysis_data: Dict[str, Any], methods: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Pure function to run causal analysis for a given analysis data.

    Args:
        analysis_data: Dictionary containing all data needed for analysis. If it
            already contains an "error" key it is returned unchanged.
        methods: List of analysis methods to use ('graph', 'component', 'dowhy',
            'confounder', 'mscd', 'ate'). If None, all methods are used.
            Unknown names are logged as a warning and skipped (no entry in the
            result).

    Returns:
        Dictionary mapping each requested, known method name to its result
        dict. A method that raised is recorded as {"error": repr(exception)};
        a method that returned its own "error" dict is stored as returned.
    """
    if methods is None:
        methods = list(_AVAILABLE_METHODS)

    # Propagate upstream failures untouched
    if "error" in analysis_data:
        return analysis_data

    # Built per call (not at module level) so names are resolved at call time,
    # which keeps monkeypatching of the helpers effective.
    dispatch = {
        'graph': _analyze_graph,
        'component': _analyze_component,
        'dowhy': _analyze_dowhy,
        'confounder': _analyze_confounder,
        'mscd': _analyze_mscd,
        'ate': _analyze_component_ate,
    }

    results = {}
    for method in tqdm(methods, desc="Running causal analysis"):
        analyze = dispatch.get(method)
        if analyze is None:
            logger.warning(f"Unknown analysis method specified: {method}")
            continue
        try:
            result_dict = analyze(analysis_data)
            results[method] = result_dict
            # Surface errors the analysis reported itself (it did not raise)
            if result_dict and isinstance(result_dict, dict) and "error" in result_dict:
                logger.error(f"Error explicitly returned by {method} analysis: {result_dict['error']}")
        except Exception as e:
            # One failing method must not abort the others; keep the traceback in the log
            logger.exception(f"Exception caught during {method} analysis: {repr(e)}")
            results[method] = {"error": repr(e)}

    return results


def _method_error(message: str, method: str) -> Dict[str, Any]:
    """Build the error payload shape used by the component-level ``_analyze_*`` helpers."""
    return {"error": message, "scores": {}, "metadata": {"method": method}}


def _component_columns(df: pd.DataFrame) -> List[str]:
    """Return the ``entity_*`` / ``relation_*`` feature columns of ``df`` in column order."""
    return [col for col in df.columns if col.startswith(_COMPONENT_PREFIXES)]


def _create_component_dataframe(analysis_data: Dict) -> pd.DataFrame:
    """
    Create a DataFrame for component analysis from the pre-filtered data.

    Each perturbation test whose ``prompt_reconstruction_id`` maps to a
    non-empty dict in ``dependencies_map`` becomes one row. The row has a
    ``perturbation`` column (the test's ``perturbation_score``) plus one 0/1
    column per entity / relation ID found in any usable test's dependencies
    (``entity_<id>`` / ``relation_<id>``, 1 when this test depends on it).
    Non-list "entities" / "relations" values are treated as empty.

    Feature columns are ordered by first appearance across the tests, so the
    layout is reproducible between runs (iterating a ``set`` of string IDs is
    not, because of string hash randomization).

    Args:
        analysis_data: Pre-filtered analysis data containing
            "perturbation_tests" and "dependencies_map"

    Returns:
        DataFrame with component features and perturbation scores, or an empty
        DataFrame when no test had usable dependencies.
    """
    perturbation_tests = analysis_data["perturbation_tests"]
    dependencies_map = analysis_data["dependencies_map"]

    # Resolve every test's dependency lists once
    usable_tests = []
    for test in perturbation_tests:
        dependencies = dependencies_map.get(test["prompt_reconstruction_id"], {})
        # Skip if dependencies not found or not a dictionary
        if not dependencies or not isinstance(dependencies, dict):
            continue
        entity_deps = dependencies.get("entities", [])
        relation_deps = dependencies.get("relations", [])
        if not isinstance(entity_deps, list):
            entity_deps = []
        if not isinstance(relation_deps, list):
            relation_deps = []
        usable_tests.append((test, entity_deps, relation_deps))

    # Unique IDs in first-seen order (dict keys preserve insertion order)
    all_entity_ids: Dict[Any, None] = {}
    all_relation_ids: Dict[Any, None] = {}
    for _, entity_deps, relation_deps in usable_tests:
        all_entity_ids.update(dict.fromkeys(entity_deps))
        all_relation_ids.update(dict.fromkeys(relation_deps))

    rows = []
    for test, entity_deps, relation_deps in usable_tests:
        # Sets make each per-column membership check O(1)
        entity_set = set(entity_deps)
        relation_set = set(relation_deps)
        row = {"perturbation": test["perturbation_score"]}
        for entity_id in all_entity_ids:
            row[f"{_ENTITY_PREFIX}{entity_id}"] = 1 if entity_id in entity_set else 0
        for relation_id in all_relation_ids:
            row[f"{_RELATION_PREFIX}{relation_id}"] = 1 if relation_id in relation_set else 0
        rows.append(row)

    df = pd.DataFrame(rows)

    # If no rows with features were created, return an empty DataFrame
    if df.empty:
        logger.warning("No rows with features could be created from the dependencies")
        return pd.DataFrame()

    return df


def _analyze_graph(analysis_data: Dict) -> Dict[str, Any]:
    """
    Perform graph-based causal analysis using pre-filtered data.

    Only relations whose "id" has an entry in ``perturbation_scores`` are kept
    in the (deep-copied) knowledge graph handed to the graph analyzer.

    Args:
        analysis_data: Pre-filtered analysis data containing "knowledge_graph"
            and "perturbation_scores" (relation ID -> score)

    Returns:
        {"scores": {"ACE": ..., "Shapley": ...}, "metadata": {...}} where the
        score mappings come from ``GraphAnalyzer.analyze()``.
    """
    kg_data = analysis_data["knowledge_graph"]
    perturbation_scores = analysis_data["perturbation_scores"]

    # Deep copy so the caller's knowledge graph is not modified
    filtered_kg = copy.deepcopy(kg_data)
    filtered_kg["relations"] = [
        rel for rel in filtered_kg.get("relations", [])
        if rel.get("id") in perturbation_scores
    ]

    causal_graph = CausalGraph(filtered_kg)
    analyzer = GraphAnalyzer(causal_graph)
    for relation_id, score in perturbation_scores.items():
        analyzer.set_perturbation_score(relation_id, score)

    ace_scores, shapley_values = analyzer.analyze()

    return {
        "scores": {
            "ACE": ace_scores,
            "Shapley": shapley_values,
        },
        "metadata": {
            "method": "graph",
            "relations_analyzed": len(filtered_kg["relations"]),
        },
    }


def _analyze_component(analysis_data: Dict) -> Dict[str, Any]:
    """
    Perform component-based causal analysis using pre-filtered data.

    Fits the influence model from ``analyze_component_influence`` on the
    component DataFrame and reports feature importances, fit metrics and the
    features whose absolute importance is at least 0.01.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation tests
            and dependencies

    Returns:
        Result dict with "scores" and "metadata", or an error dict (with empty
        "scores") when no usable DataFrame could be built.
    """
    df = _create_component_dataframe(analysis_data)
    if df is None or df.empty:
        logger.error("Failed to create or empty DataFrame for component analysis")
        return _method_error("Failed to create or empty DataFrame for component analysis", "component")

    if 'perturbation' not in df.columns:
        logger.error("'perturbation' column missing from DataFrame.")
        return _method_error("'perturbation' column missing from DataFrame.", "component")

    # The analysis also returns the feature columns it actually used
    model, feature_importance, feature_cols = analyze_component_influence(df)

    if feature_cols:
        metrics = evaluate_model(model, df[feature_cols], df['perturbation'])
    else:
        # No features were used (e.g. no variance): report a trivial fit whose
        # R^2 is 1 only when the target itself is constant
        metrics = {'mse': 0.0, 'rmse': 0.0, 'r2': 1.0 if df['perturbation'].std() == 0 else 0.0}

    key_components = [
        feature for feature, importance in feature_importance.items()
        if abs(importance) >= _KEY_COMPONENT_THRESHOLD
    ]

    return {
        "scores": {
            "Feature_Importance": feature_importance,
            "Model_Metrics": metrics,
            "Key_Components": key_components,
        },
        "metadata": {
            "method": "component",
            # NOTE(review): the model kind is decided inside
            # analyze_component_influence; confirm this label still matches it.
            "model_type": "LinearModel",
            "rows_analyzed": len(df),
        },
    }


def _find_cooccurrence_confounders(
    df: pd.DataFrame,
    components: List[str],
    threshold: float = 1.5,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Flag component pairs that co-occur more often than independence predicts.

    For each pair (comp1, comp2) with comp1 listed before comp2, the number of
    rows where both columns equal 1 is compared with the count expected under
    independence, ``count(comp1) * count(comp2) / len(df)``. Pairs whose ratio
    exceeds ``threshold`` are recorded under ``comp1``.

    Counts and ratios are returned as built-in int/float (not numpy scalars)
    so the result can be passed to ``json.dumps``.
    """
    n_rows = len(df)
    # Compute each column's presence mask and count once instead of per pair
    present = {comp: (df[comp] == 1) for comp in components}
    counts = {comp: int(mask.sum()) for comp, mask in present.items()}

    confounders: Dict[str, List[Dict[str, Any]]] = {}
    for i, comp1 in enumerate(components):
        if counts[comp1] == 0:
            continue
        for comp2 in components[i + 1:]:
            if counts[comp2] == 0:
                continue
            expected = (counts[comp1] * counts[comp2]) / n_rows
            if expected <= 0:
                continue
            both_present = int((present[comp1] & present[comp2]).sum())
            co_occurrence_ratio = both_present / expected
            if co_occurrence_ratio > threshold:
                confounders.setdefault(comp1, []).append({
                    "confounder": comp2,
                    "co_occurrence_ratio": float(co_occurrence_ratio),
                    "both_present": both_present,
                    "expected": float(expected),
                })
    return confounders


def _collect_interaction_effects(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Merge per-component interaction data from DoWhy results.

    'interacts_with' entries are taken as-is (shallow-copied); entries from
    'interaction_effects' are appended as {'component', 'interaction_coefficient'}
    dicts. Results without a truthy 'component' are ignored.
    """
    interaction_effects: Dict[str, Any] = {}
    for result in results:
        component = result.get('component')
        if not component:
            continue
        if 'interacts_with' in result:
            # Copy so the appends below do not mutate the DoWhy result's own list
            interaction_effects[component] = copy.copy(result['interacts_with'])
        if 'interaction_effects' in result:
            entries = interaction_effects.setdefault(component, [])
            for interaction in result['interaction_effects']:
                entries.append({
                    'component': interaction['component'],
                    'interaction_coefficient': interaction['interaction_coefficient'],
                })
    return interaction_effects


def _analyze_dowhy(analysis_data: Dict) -> Dict[str, Any]:
    """
    Perform DoWhy-based causal analysis using pre-filtered data.

    Also reports component pairs that co-occur more than 1.5x as often as
    independence would predict, as candidate confounders.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation tests
            and dependencies

    Returns:
        Result dict with "scores" (Effect_Estimate, Refutation_Results,
        Interaction_Effects, Confounders) and "metadata", or an error dict.
    """
    df = _create_component_dataframe(analysis_data)
    if df is None or df.empty:
        return _method_error("Failed to create DataFrame for DoWhy analysis", "dowhy")

    components = _component_columns(df)
    if not components:
        return _method_error("No component features found for DoWhy analysis", "dowhy")

    confounders = _find_cooccurrence_confounders(df, components, threshold=1.5)

    logger.info("Running DoWhy analysis with all %d components", len(components))
    results = analyze_components_with_dowhy(df, components)

    effect_estimates = {r['component']: r.get('effect_estimate', 0) for r in results}
    refutation_results = {r['component']: r.get('refutation_results', []) for r in results}
    interaction_effects = _collect_interaction_effects(results)

    return {
        "scores": {
            "Effect_Estimate": effect_estimates,
            "Refutation_Results": refutation_results,
            "Interaction_Effects": interaction_effects,
            "Confounders": confounders,
        },
        "metadata": {
            "method": "dowhy",
            "analysis_type": "backdoor.linear_regression",
            "rows_analyzed": len(df),
            "components_analyzed": len(components),
        },
    }


def _analyze_confounder(analysis_data: Dict) -> Dict[str, Any]:
    """
    Perform confounder detection analysis using pre-filtered data.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation tests
            and dependencies

    Returns:
        Result dict with "scores" (Confounders, Impact_Analysis, Summary taken
        from ``run_confounder_analysis``) and "metadata", or an error dict.
    """
    df = _create_component_dataframe(analysis_data)
    if df is None or df.empty:
        return _method_error("Failed to create DataFrame for confounder analysis", "confounder")

    components = _component_columns(df)
    if not components:
        return _method_error("No component features found for confounder analysis", "confounder")

    logger.info("Running confounder detection analysis with %d components", len(components))
    confounder_results = run_confounder_analysis(
        df,
        outcome_var="perturbation",
        cooccurrence_threshold=1.2,
        min_occurrences=2,
        # Fresh list per call so the shared constant cannot be mutated
        specific_confounder_pairs=list(_SPECIFIC_CONFOUNDER_PAIRS),
    )

    return {
        "scores": {
            "Confounders": confounder_results.get("confounders", {}),
            "Impact_Analysis": confounder_results.get("impact_analysis", {}),
            "Summary": confounder_results.get("summary", {}),
        },
        "metadata": {
            "method": "confounder",
            "rows_analyzed": len(df),
            "components_analyzed": len(components),
        },
    }


def _analyze_mscd(analysis_data: Dict) -> Dict[str, Any]:
    """
    Perform Multi-Signal Confounder Detection (MSCD) analysis using pre-filtered data.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation tests
            and dependencies

    Returns:
        Result dict with "scores" (Confounders, Method_Results, Summary taken
        from ``run_mscd_analysis``) and "metadata", or an error dict.
    """
    df = _create_component_dataframe(analysis_data)
    if df is None or df.empty:
        return _method_error("Failed to create DataFrame for MSCD analysis", "mscd")

    components = _component_columns(df)
    if not components:
        return _method_error("No component features found for MSCD analysis", "mscd")

    logger.info("Running Multi-Signal Confounder Detection with %d components", len(components))
    mscd_results = run_mscd_analysis(
        df,
        outcome_var="perturbation",
        # Fresh list per call so the shared constant cannot be mutated
        specific_confounder_pairs=list(_SPECIFIC_CONFOUNDER_PAIRS),
    )

    return {
        "scores": {
            "Confounders": mscd_results.get("combined_confounders", {}),
            "Method_Results": mscd_results.get("method_results", {}),
            "Summary": mscd_results.get("summary", {}),
        },
        "metadata": {
            "method": "mscd",
            "rows_analyzed": len(df),
            "components_analyzed": len(components),
        },
    }


def _analyze_component_ate(analysis_data: Dict) -> Dict[str, Any]:
    """
    Perform Component Average Treatment Effect (ATE) analysis using pre-filtered data.

    Steps: compute effect strengths for all components, run Granger causality
    tests on the (up to) 10 components with the largest absolute effect, then
    compute an ATE per component. Per-component failures are logged and
    replaced with neutral placeholder results; any other failure yields
    {"error": ...}.

    Args:
        analysis_data: Pre-filtered analysis data containing perturbation tests
            and dependencies

    Returns:
        Result dict with "scores" (Effect_Strengths, Granger_Results,
        ATE_Results) and "metadata", or {"error": message}.
    """
    try:
        logger.info("Starting Component ATE analysis")

        df = _create_component_dataframe(analysis_data)
        if df is None or df.empty:
            logger.error("Failed to create component DataFrame for ATE analysis")
            return {"error": "Failed to create component DataFrame"}

        component_cols = _component_columns(df)
        if not component_cols:
            logger.error("No component features found in DataFrame for ATE analysis")
            return {"error": "No component features found"}

        # 1. Compute causal effect strengths (ATE)
        logger.info("Computing causal effect strengths (ATE)")
        effect_strengths = compute_causal_effect_strength(df)
        sorted_effects = sorted(effect_strengths.items(), key=lambda item: abs(item[1]), reverse=True)

        # 2. Run Granger causality tests on top components
        logger.info("Running Granger causality tests on top components")
        top_components = [comp for comp, _ in sorted_effects[:_GRANGER_TOP_K]]
        granger_results = {}
        for component in top_components:
            try:
                granger_results[component] = granger_causality_test(df, component)
            except Exception as e:
                logger.warning(f"Error in Granger causality test for {component}: {e}")
                granger_results[component] = {
                    "f_statistic": 0.0,
                    "p_value": 1.0,
                    "causal_direction": "error",
                }

        # 3. Calculate ATE for all components
        logger.info("Computing ATE for all components")
        ate_results = {}
        for component in component_cols:
            try:
                ate_results[component] = calculate_average_treatment_effect(df, component)
            except Exception as e:
                logger.warning(f"Error computing ATE for {component}: {e}")
                ate_results[component] = {
                    "ate": 0.0,
                    "std_error": 0.0,
                    "t_statistic": 0.0,
                    "p_value": 1.0,
                }

        return {
            "scores": {
                "Effect_Strengths": effect_strengths,
                "Granger_Results": granger_results,
                "ATE_Results": ate_results,
            },
            "metadata": {
                "method": "ate",
                "components_analyzed": len(component_cols),
                "top_components_tested": len(top_components),
                "rows_analyzed": len(df),
            },
        }
    except Exception as e:
        logger.exception(f"Error in Component ATE analysis: {str(e)}")
        return {"error": f"Component ATE analysis failed: {str(e)}"}


def _score_for(scores: Dict[Any, Any], prefix: str, component_id: Any, default: Any) -> Any:
    """
    Look up a per-component score by DataFrame column name, falling back to the raw ID.

    The 'component' and 'dowhy' analyses key their scores by the feature column
    names built in _create_component_dataframe (``entity_<id>`` /
    ``relation_<id>``), so the prefixed key is tried first.
    """
    column = f"{prefix}{component_id}"
    if column in scores:
        return scores[column]
    return scores.get(component_id, default)


def _is_key_component(key_components: List[Any], prefix: str, component_id: Any) -> bool:
    """Return True if the component appears in ``key_components`` by column name or raw ID."""
    return f"{prefix}{component_id}" in key_components or component_id in key_components


def _attach_causal_attribution(items: List[Dict[str, Any]], prefix: str, results: Dict[str, Any]) -> None:
    """
    Set ``item["causal_attribution"]`` on every entity/relation dict in ``items``.

    Only successful 'graph', 'component' and 'dowhy' results contribute;
    results containing an "error" key and other methods are skipped.
    """
    for item in items:
        item_id = item["id"]
        attribution: Dict[str, Any] = {}
        item["causal_attribution"] = attribution

        for method, result in results.items():
            if "error" in result:
                continue
            scores = result["scores"] if method in ("graph", "component", "dowhy") else None
            if method == "graph":
                # Graph scores are keyed by the knowledge-graph IDs themselves
                attribution["graph"] = {
                    "ACE": scores["ACE"].get(item_id, 0),
                    "Shapley": scores["Shapley"].get(item_id, 0),
                }
            elif method == "component":
                attribution["component"] = {
                    "Feature_Importance": _score_for(scores["Feature_Importance"], prefix, item_id, 0),
                    "Is_Key_Component": _is_key_component(scores["Key_Components"], prefix, item_id),
                }
            elif method == "dowhy":
                attribution["dowhy"] = {
                    "Effect_Estimate": _score_for(scores["Effect_Estimate"], prefix, item_id, 0),
                    "Refutation_Results": _score_for(scores["Refutation_Results"], prefix, item_id, []),
                }


def enrich_knowledge_graph(kg_data: Dict, results: Dict[str, Any]) -> Dict:
    """
    Enrich knowledge graph with causal attribution scores from all methods.

    Works on a deep copy; every entity and relation gets a
    "causal_attribution" dict with one entry per successful 'graph',
    'component' or 'dowhy' result. Missing scores default to 0 (or [] for
    refutation results).

    Args:
        kg_data: Original knowledge graph data
        results: Analysis results from all methods

    Returns:
        Enriched knowledge graph with causal attributions from all methods

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("No analysis results available")

    enriched_kg = copy.deepcopy(kg_data)
    _attach_causal_attribution(enriched_kg.get("entities", []), _ENTITY_PREFIX, results)
    _attach_causal_attribution(enriched_kg.get("relations", []), _RELATION_PREFIX, results)
    return enriched_kg


def generate_report(kg_data: Dict, results: Dict[str, Any]) -> Dict[str, Any]:
    """
    Generate a comprehensive report of causal analysis results.

    Args:
        kg_data: Original knowledge graph data
        results: Analysis results from all methods

    Returns:
        Dictionary with "summary", "method_results", "key_findings" and
        "recommendations", or {"error": ...} when ``results`` is empty.
    """
    if not results:
        return {"error": "No analysis results available for report generation"}

    successful_methods = [method for method, result in results.items() if "error" not in result]
    failed_methods = [method for method, result in results.items() if "error" in result]

    report = {
        "summary": {
            "total_entities": len(kg_data.get("entities", [])),
            "total_relations": len(kg_data.get("relations", [])),
            "methods_used": list(results.keys()),
            "successful_methods": successful_methods,
            "failed_methods": failed_methods,
        },
        "method_results": {},
        "key_findings": [],
        "recommendations": [],
    }

    # Compile results from each method
    for method, result in results.items():
        if "error" in result:
            report["method_results"][method] = {"status": "failed", "error": result["error"]}
        else:
            report["method_results"][method] = {
                "status": "success",
                "scores": result.get("scores", {}),
                "metadata": result.get("metadata", {}),
            }

    # Generate key findings
    if "graph" in results and "error" not in results["graph"]:
        ace_scores = results["graph"]["scores"].get("ACE", {})
        if ace_scores:
            top_ace = max(ace_scores.items(), key=lambda item: abs(item[1]))
            report["key_findings"].append(f"Strongest causal effect detected on {top_ace[0]} (ACE: {top_ace[1]:.3f})")

    if "component" in results and "error" not in results["component"]:
        key_components = results["component"]["scores"].get("Key_Components", [])
        if key_components:
            report["key_findings"].append(f"Key causal components identified: {', '.join(key_components[:5])}")

    # Generate recommendations
    if report["summary"]["failed_methods"]:
        report["recommendations"].append("Consider investigating failed analysis methods for data quality issues")

    if report["summary"]["total_relations"] < 10:
        report["recommendations"].append("Small knowledge graph may limit causal analysis accuracy")

    return report