Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| DoWhy Causal Component Analysis | |
| This script implements causal inference methods using the DoWhy library to analyze | |
| the causal relationship between knowledge graph components and perturbation scores. | |
| """ | |
| import os | |
| import sys | |
| import pandas as pd | |
| import numpy as np | |
| import argparse | |
| import logging | |
| import json | |
| from typing import Dict, List, Optional, Tuple, Set | |
| from collections import defaultdict | |
| # Import DoWhy | |
| import dowhy | |
| from dowhy import CausalModel | |
| # Import from utils directory | |
| from .utils.dataframe_builder import create_component_influence_dataframe | |
| # Import shared utilities | |
| from .utils.shared_utils import create_mock_perturbation_scores, list_available_components | |
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)

# Root logging setup: CRITICAL threshold and a timestamped record format.
# NOTE(review): when this basicConfig call is the one that configures the root
# logger, records emitted through `logger` below CRITICAL (this module's own
# warnings and errors included) are filtered out unless a level is set
# elsewhere -- confirm that this is intended.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.CRITICAL,
)

# DoWhy loggers that are given an explicit WARNING threshold so that their
# INFO/DEBUG chatter is dropped.
_NOISY_LOGGER_NAMES = (
    "dowhy",
    "dowhy.causal_estimator",
    "dowhy.causal_model",
    "dowhy.causal_refuter",
    "dowhy.do_sampler",
    "dowhy.identifier",
    "dowhy.propensity_score",
    "dowhy.utils",
    "dowhy.causal_refuter.add_unobserved_common_cause",
)
for noisy_logger in _NOISY_LOGGER_NAMES:
    logging.getLogger(noisy_logger).setLevel(logging.WARNING)
| # Note: create_mock_perturbation_scores and list_available_components | |
| # moved to utils.shared_utils to avoid duplication | |
def generate_simple_causal_graph(
    df: pd.DataFrame,
    treatment: str,
    outcome: str,
    confounder_threshold: float = 0.7,
) -> str:
    """
    Generate a simple causal graph in a format compatible with DoWhy.

    Every column whose name starts with ``entity_`` or ``relation_`` (other
    than the treatment and the outcome themselves) is treated as a component.
    A component whose absolute Pearson correlation with the treatment reaches
    ``confounder_threshold`` is modelled as a confounder (edges to both the
    treatment and the outcome); every other component only gets an edge to
    the outcome.

    Args:
        df: DataFrame with features
        treatment: Treatment variable name
        outcome: Outcome variable name
        confounder_threshold: Minimum absolute correlation with the treatment
            for a component to be considered a confounder (default: 0.7)

    Returns:
        String representation of the causal graph in DOT format
    """
    # The outcome is excluded as well as the treatment: if its name carries a
    # component prefix it would otherwise be emitted as a self-loop
    # ("outcome" -> "outcome"), making the graph cyclic.
    component_cols = [
        col for col in df.columns
        if col.startswith(('entity_', 'relation_')) and col not in (treatment, outcome)
    ]

    potential_confounders = []
    if component_cols:
        treatment_series = df[treatment]
        # A zero-variance treatment yields NaN correlations, so no component
        # can qualify; this check is loop-invariant and evaluated only once.
        if treatment_series.std() != 0:
            for component in component_cols:
                component_series = df[component]
                # Skip if no variance (would result in correlation NaN)
                if component_series.std() == 0:
                    continue
                correlation = component_series.corr(treatment_series)
                if abs(correlation) >= confounder_threshold:
                    potential_confounders.append(component)

    confounder_set = set(potential_confounders)

    # Collect the edges and join once instead of growing a string with +=.
    edges = [f'"{treatment}" -> "{outcome}";']
    for confounder in potential_confounders:
        # Confounder affects both treatment and outcome
        edges.append(f'"{confounder}" -> "{treatment}";')
        edges.append(f'"{confounder}" -> "{outcome}";')
    # Remaining components are potential causes of the outcome only.
    edges.extend(
        f'"{component}" -> "{outcome}";'
        for component in component_cols
        if component not in confounder_set
    )
    return "digraph {" + "".join(edges) + "}"
# (method_name passed to DoWhy, "method" value stored in the result, label used in log messages)
_REFUTATION_METHODS = (
    ("random_common_cause", "random_common_cause", "Random common cause"),
    ("placebo_treatment_refuter", "placebo_treatment", "Placebo treatment"),
    ("data_subset_refuter", "data_subset", "Data subset"),
)


def _find_interaction_candidates(df: pd.DataFrame, treatment: str, outcome_var: str) -> List[str]:
    """Return components that co-occur with the treatment more than 1.5x the chance level."""
    # Only check if the treatment appears in the data
    if not df[treatment].sum() > 0:
        return []

    # Presence masks are built with `!= 0` rather than applying bitwise `&`
    # to the raw columns: `&` raises on float-typed indicator columns and
    # miscounts integer values other than 0/1.
    treatment_present = df[treatment] != 0
    treatment_rate = df[treatment].mean()
    row_count = len(df)

    candidates = []
    for col in df.columns:
        if not col.startswith(('entity_', 'relation_')) or col in (treatment, outcome_var):
            continue
        # Skip components with no occurrences
        if df[col].sum() == 0:
            continue
        # Co-occurrence expected if treatment and component were independent.
        expected_cooccurrence = treatment_rate * df[col].mean() * row_count
        actual_cooccurrence = (treatment_present & (df[col] != 0)).sum()
        if actual_cooccurrence > 1.5 * expected_cooccurrence:
            candidates.append(col)
    return candidates


def _estimate_interaction_effects(
    df: pd.DataFrame,
    treatment: str,
    outcome_var: str,
    interaction_components: List[str],
) -> List[Dict]:
    """Fit outcome ~ treatment + component + treatment*component and collect each interaction coefficient."""
    interaction_effects = []
    for interaction_component in interaction_components:
        interaction_col = f"{treatment}_x_{interaction_component}"
        # Best-effort per component: a failure is logged and the remaining
        # components are still analysed.
        try:
            from sklearn.linear_model import LinearRegression

            # The interaction term lives in a private copy, so the caller's
            # DataFrame is never modified (and no temporary column can be
            # left behind when the fit fails).
            features = df[[treatment, interaction_component]].copy()
            features[interaction_col] = features[treatment] * features[interaction_component]

            regression = LinearRegression()
            regression.fit(features, df[outcome_var])
            interaction_effects.append({
                "component": interaction_component,
                # Index 2 is the interaction term (third feature column).
                "interaction_coefficient": float(regression.coef_[2])
            })
        except Exception as exc:
            logger.warning(f"Error analyzing interaction with {interaction_component}: {exc}")
    return interaction_effects


def _run_refutations(model, identified_estimand, estimate) -> List[Dict]:
    """Run each configured DoWhy refuter, logging and skipping any that fails."""
    refutation_results = []
    for method_name, result_key, label in _REFUTATION_METHODS:
        try:
            refutation = model.refute_estimate(
                identified_estimand,
                estimate,
                method_name=method_name
            )
            refutation_results.append({
                "method": result_key,
                "refutation_result": str(refutation)
            })
        except Exception as exc:
            logger.warning(f"{label} refutation failed: {exc}")
    return refutation_results


def run_dowhy_analysis(
    df: pd.DataFrame,
    treatment_component: str,
    outcome_var: str = "perturbation",
    proceed_when_unidentifiable: bool = True
) -> Dict:
    """
    Run causal analysis using DoWhy for a single treatment component.

    The steps are: detect components that co-occur unusually often with the
    treatment, build a causal graph, identify the estimand, estimate the
    effect with ``backdoor.linear_regression``, estimate interaction
    coefficients for the co-occurring components, and run three refuters
    (random common cause, placebo treatment, data subset). The input
    DataFrame is not modified.

    Args:
        df: DataFrame with binary component features and outcome variable
        treatment_component: Name of the component to analyze
        outcome_var: Name of the outcome variable
        proceed_when_unidentifiable: Whether to proceed when effect is unidentifiable

    Returns:
        Dictionary with causal analysis results. It always has a
        ``"component"`` key. On success it also has ``"identified_estimand"``,
        ``"effect_estimate"``, ``"refutation_results"`` and, when any were
        found, ``"interaction_effects"``. On failure it has an ``"error"``
        key instead of the estimate.
    """
    # Ensure the treatment_component is present in the data
    if treatment_component not in df.columns:
        logger.error(f"Treatment component {treatment_component} not found in DataFrame")
        return {"component": treatment_component, "error": "Component not found"}
    treatment = treatment_component

    # Components that co-occur with the treatment more often than chance are
    # candidates for an interaction term.
    interaction_components = _find_interaction_candidates(df, treatment, outcome_var)

    # Generate a simple causal graph
    graph = generate_simple_causal_graph(df, treatment, outcome_var)

    try:
        model = CausalModel(
            data=df,
            treatment=treatment,
            outcome=outcome_var,
            graph=graph,
            proceed_when_unidentifiable=proceed_when_unidentifiable
        )
        # Log the graph (for debugging)
        logger.info(f"Causal graph for {treatment}: {graph}")

        # Identify the causal effect
        identified_estimand = model.identify_effect(proceed_when_unidentifiable=proceed_when_unidentifiable)
        logger.info(f"Identified estimand for {treatment}")

        # If there's no variance in the outcome, we can't estimate effect
        if df[outcome_var].std() == 0:
            logger.warning(f"No variance in outcome variable {outcome_var}, skipping estimation")
            return {
                # NOTE(review): only this branch strips a "comp_" prefix from
                # the component name; every other branch returns the name
                # unchanged. Kept as-is for compatibility -- confirm intent.
                "component": treatment.replace("comp_", ""),
                "identified_estimand": str(identified_estimand),
                "error": "No variance in outcome variable"
            }

        # Estimate the causal effect
        try:
            estimate = model.estimate_effect(
                identified_estimand,
                method_name="backdoor.linear_regression",
                target_units="ate",
                test_significance=None
            )
            logger.info(f"Estimated causal effect for {treatment}: {estimate.value}")

            interaction_effects = _estimate_interaction_effects(
                df, treatment, outcome_var, interaction_components
            )
            refutation_results = _run_refutations(model, identified_estimand, estimate)

            result = {
                "component": treatment,
                "identified_estimand": str(identified_estimand),
                "effect_estimate": float(estimate.value),
                "refutation_results": refutation_results
            }
            # Add interaction effects if found
            if interaction_effects:
                result["interaction_effects"] = interaction_effects
            return result
        except Exception as e:
            logger.error(f"Error estimating effect for {treatment}: {e}")
            return {
                "component": treatment,
                "identified_estimand": str(identified_estimand),
                "error": f"Estimation error: {e}"
            }
    except Exception as e:
        logger.error(f"Error in causal analysis for {treatment}: {e}")
        return {
            "component": treatment,
            "error": str(e)
        }
def analyze_components_with_dowhy(
    df: pd.DataFrame,
    components_to_analyze: List[str]
) -> List[Dict]:
    """
    Analyze causal effects of multiple components using DoWhy.

    Each component is analysed with ``run_dowhy_analysis`` and a short
    summary is printed. Afterwards, every successful result whose component
    was reported as an interaction partner by another component's analysis
    gets an ``"interacts_with"`` list describing those partners.

    Args:
        df: DataFrame with binary component features and outcome variable
        components_to_analyze: List of component names to analyze

    Returns:
        List of dictionaries with causal analysis results, in the same order
        as ``components_to_analyze``
    """
    results = []
    # Maps a component to the analysed components that reported an
    # interaction with it.
    interaction_map = defaultdict(list)

    # First, analyze each component individually
    for component in components_to_analyze:
        print(f"\nAnalyzing causal effect of component: {component}")
        result = run_dowhy_analysis(df, component)
        results.append(result)

        # Print result summary
        if "error" in result:
            print(f" Error: {result['error']}")
            continue
        print(f" Estimated causal effect: {result.get('effect_estimate', 'N/A')}")

        # Track interactions if found
        for interaction in result.get("interaction_effects", []):
            interacting_component = interaction["component"]
            interaction_coef = interaction["interaction_coefficient"]
            interaction_map[interacting_component].append({
                "component": component,
                "interaction_coefficient": interaction_coef
            })
            print(f" Interaction with {interacting_component}: {interaction_coef}")

    # Post-process: attach the reverse view of the interactions to each
    # successful result.
    for result in results:
        component = result.get("component")
        # Skip results with errors
        if "error" in result or not component:
            continue
        # .get() avoids creating empty defaultdict entries for non-partners.
        partners = interaction_map.get(component)
        if partners:
            result["interacts_with"] = partners
    return results
def _build_arg_parser():
    """Create the command-line parser for the DoWhy component analysis script."""
    parser = argparse.ArgumentParser(description='DoWhy Causal Component Analysis')
    parser.add_argument('--test', action='store_true', help='Enable test mode with mock perturbation scores')
    parser.add_argument('--components', nargs='+', help='Component names to test in test mode')
    parser.add_argument('--treatments', nargs='+', help='Component names to treat as treatments for causal analysis')
    parser.add_argument('--list-components', action='store_true', help='List available components and exit')
    parser.add_argument('--base-score', type=float, default=1.0, help='Base perturbation score (default: 1.0)')
    parser.add_argument('--treatment-score', type=float, default=0.2, help='Score for test components (default: 0.2)')
    parser.add_argument('--json-file', type=str, help='Path to JSON file (default: example.json)')
    parser.add_argument('--top-k', type=int, default=5, help='Number of top components to analyze (default: 5)')
    return parser


def _choose_test_components(args, df):
    """Return the components used for mock scores, or None when the DataFrame offers none."""
    if args.components:
        return args.components

    logger.warning("No components specified for test mode. Using random components.")
    available = list_available_components(df)
    if len(available) > 0:
        # Pick at most two distinct components at random.
        return np.random.choice(available,
                                size=min(2, len(available)),
                                replace=False).tolist()

    logger.error("No components found in DataFrame. Cannot create mock scores.")
    return None


def _print_dataframe_overview(df):
    """Print row/feature counts and a summary of the 'perturbation' column."""
    comp_columns = [name for name in df.columns if name.startswith("comp_")]
    other_columns = [name for name in df.columns if not name.startswith('comp_')]
    print("\nDataFrame info:")
    print(f"Rows: {len(df)}")
    print(f"Features: {len(comp_columns)}")
    print(f"Columns: {', '.join(other_columns)}")

    scores = df['perturbation']
    # Check if we have any variance in perturbation scores
    if scores.std() == 0:
        print("\nWARNING: All perturbation scores are identical (value: %.2f)." % scores.iloc[0])
        print(" This will limit the effectiveness of causal analysis.")
        print(" Consider using synthetic data with varied perturbation scores for better results.\n")
        return

    print("\nPerturbation score statistics:")
    print(f"Min: {scores.min():.2f}")
    print(f"Max: {scores.max():.2f}")
    print(f"Mean: {scores.mean():.2f}")
    print(f"Std: {scores.std():.2f}")


def _save_results(output_path, json_file, test_mode, components_to_analyze, results):
    """Write the analysis results plus run metadata to output_path as JSON."""
    payload = {
        "metadata": {
            "json_file": json_file,
            "test_mode": test_mode,
            "components_analyzed": components_to_analyze,
        },
        "results": results
    }
    try:
        with open(output_path, 'w') as handle:
            json.dump(payload, handle, indent=2)
        logger.info(f"Causal analysis results saved to {output_path}")
        print(f"\nCausal analysis complete. Results saved to {output_path}")
    except Exception as e:
        # A failed save is reported but does not raise.
        logger.error(f"Error saving results to {output_path}: {str(e)}")
        print(f"\nError saving results: {str(e)}")


def main():
    """Main function to run the DoWhy causal component analysis."""
    args = _build_arg_parser().parse_args()
    script_dir = os.path.dirname(__file__)

    # Use the user-specified file, falling back to example.json next to this script.
    json_file = args.json_file or os.path.join(script_dir, 'example.json')

    df = create_component_influence_dataframe(json_file)
    if df is None or df.empty:
        logger.error("Failed to create or empty DataFrame. Cannot proceed with analysis.")
        return

    # List components if requested, then exit.
    if args.list_components:
        print("\nAvailable components:")
        for position, name in enumerate(list_available_components(df), 1):
            print(f"{position}. {name}")
        return

    # In test mode the perturbation scores are replaced by mock values.
    if args.test:
        test_components = _choose_test_components(args, df)
        if test_components is None:
            return
        print(f"\nTest mode enabled. Using components: {', '.join(test_components)}")
        print(f"Setting base score: {args.base_score}, treatment score: {args.treatment_score}")
        df = create_mock_perturbation_scores(
            df,
            test_components,
            base_score=args.base_score,
            treatment_score=args.treatment_score
        )

    _print_dataframe_overview(df)

    # Explicit treatments win; otherwise take the first top-k available components.
    if args.treatments:
        components_to_analyze = args.treatments
    else:
        components_to_analyze = list_available_components(df)[:args.top_k]
    print(f"\nAnalyzing {len(components_to_analyze)} components as treatments: {', '.join(components_to_analyze)}")

    # Run DoWhy causal analysis for each treatment component
    results = analyze_components_with_dowhy(df, components_to_analyze)

    # Results are written next to this script; test runs use a separate file name.
    output_filename = 'test_dowhy_causal_effects.json' if args.test else 'dowhy_causal_effects.json'
    output_path = os.path.join(script_dir, output_filename)
    _save_results(output_path, json_file, args.test, components_to_analyze, results)


if __name__ == "__main__":
    main()