Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Causal Component Analysis | |
| This script implements causal inference methods to analyze the causal relationship | |
| between knowledge graph components and perturbation scores. | |
| """ | |
| import os | |
| import sys | |
| import pandas as pd | |
| import numpy as np | |
| import logging | |
| import argparse | |
| from typing import Dict, List, Optional, Tuple, Set | |
| from sklearn.linear_model import LinearRegression | |
| # Import from utils directory | |
| from .utils.dataframe_builder import create_component_influence_dataframe | |
| # Import shared utilities | |
| from .utils.shared_utils import list_available_components | |
| # Module-level logger named after this module; NOTE(review): the basicConfig call below runs at import time and sets up root-logger output (INFO level, timestamped format) unless the root logger already has handlers | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
def calculate_average_treatment_effect(
    df: pd.DataFrame,
    component_id: str,
    outcome_var: str = "perturbation",
    control_vars: Optional[List[str]] = None,
    n_bootstrap: int = 1000,
    random_state: Optional[int] = None,
) -> Dict[str, float]:
    """
    Calculates the Average Treatment Effect (ATE) of a component on the outcome.

    The ATE is the coefficient of ``component_id`` in an ordinary least squares
    regression of ``outcome_var`` on the component plus the control variables
    (regression adjustment). Its uncertainty is estimated with a row-resampling
    bootstrap.

    Args:
        df: DataFrame with binary component features and the outcome column
        component_id: ID of the component to analyze (including 'entity_' or 'relation_' prefix)
        outcome_var: Name of the outcome variable (default: 'perturbation')
        control_vars: Control variables to adjust for. When None, every other
            'entity_'/'relation_' column is used. The treatment and outcome
            columns are always removed from this list.
        n_bootstrap: Number of bootstrap resamples used for the standard error
            and confidence interval (default: 1000)
        random_state: Optional seed for a private random generator. When None,
            the global ``np.random`` state is used.

    Returns:
        Dictionary with keys 'ate', 'naive_ate' (difference of group means
        without adjustment), 'std_error', 'p_value' (two-sided, normal
        approximation) and 'confidence_interval_95' (bootstrap percentile
        interval, as a (lower, upper) tuple). When the effect cannot be
        estimated, a neutral result (zero effect, p-value 1.0) is returned.
    """
    import math  # local import: only needed for the normal-tail p-value

    null_result = {
        "ate": 0.0,
        "naive_ate": 0.0,
        "std_error": 0.0,
        "p_value": 1.0,
        "confidence_interval_95": (0.0, 0.0),
    }

    if component_id not in df.columns:
        logger.error(f"Component {component_id} not found in DataFrame")
        return null_result
    if outcome_var not in df.columns:
        logger.error(f"Outcome {outcome_var} not found in DataFrame")
        return null_result

    # nunique() is used instead of std() == 0: the standard deviation of a
    # constant float column may be a tiny non-zero number because of rounding,
    # and it is NaN (never equal to 0) for a single-row frame.
    if df[component_id].nunique() < 2:
        logger.warning(f"No variation in component {component_id}, cannot estimate causal effect")
        return null_result
    if df[outcome_var].nunique() < 2:
        logger.warning(f"No variation in outcome {outcome_var}, cannot estimate causal effect")
        return null_result

    # Select control variables (other components that could confound the
    # relationship). The treatment must not appear among the controls: a
    # duplicated column is perfectly collinear and would split the effect
    # between the two copies. The outcome must not be a regressor either.
    excluded = (component_id, outcome_var)
    if control_vars is None:
        controls = [
            col for col in df.columns
            if col.startswith(("entity_", "relation_")) and col not in excluded
        ]
    else:
        controls = [col for col in control_vars if col not in excluded]

    # The treatment is always the last regressor, so its coefficient is coef_[-1].
    feature_cols = controls + [component_id]

    # Naive ATE: difference of outcome means between treated and untreated rows.
    treated_mean = df.loc[df[component_id] == 1, outcome_var].mean()
    untreated_mean = df.loc[df[component_id] == 0, outcome_var].mean()
    naive_ate = treated_mean - untreated_mean

    # Regression adjustment on plain arrays (converted once, reused by the bootstrap).
    X_values = df[feature_cols].to_numpy()
    y_values = df[outcome_var].to_numpy()
    ate = float(LinearRegression().fit(X_values, y_values).coef_[-1])

    # Bootstrap: refit the same model on rows sampled with replacement.
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    n_rows = len(df)
    bootstrap_ates = []
    for _ in range(n_bootstrap):
        sample_idx = rng.choice(n_rows, n_rows, replace=True)
        try:
            sample_model = LinearRegression().fit(X_values[sample_idx], y_values[sample_idx])
        except (ValueError, np.linalg.LinAlgError):
            # Skip resamples on which the fit fails; anything else is a real bug.
            continue
        bootstrap_ates.append(float(sample_model.coef_[-1]))

    if bootstrap_ates:
        std_error = float(np.std(bootstrap_ates))
        ci_lower = float(np.percentile(bootstrap_ates, 2.5))
        ci_upper = float(np.percentile(bootstrap_ates, 97.5))
    else:
        logger.warning(f"No usable bootstrap samples for {component_id}; uncertainty not estimated")
        std_error, ci_lower, ci_upper = 0.0, ate, ate

    # Two-sided p-value under a normal approximation:
    # 2 * (1 - Phi(|z|)) == erfc(|z| / sqrt(2)), which always lies in [0, 1].
    if std_error > 0:
        z_score = ate / std_error
        p_value = math.erfc(abs(z_score) / math.sqrt(2.0))
    else:
        p_value = 1.0

    return {
        "ate": ate,
        "naive_ate": float(naive_ate),
        "std_error": std_error,
        "p_value": p_value,
        "confidence_interval_95": (ci_lower, ci_upper),
    }
def _granger_f_test(
    lagged_df: pd.DataFrame,
    target: str,
    own_lag_cols: List[str],
    other_lag_cols: List[str],
    df_num: int,
    df_den: int,
) -> Tuple[float, float]:
    """Return (F statistic, p-value) for whether other_lag_cols improve an OLS fit of target on own_lag_cols."""
    from scipy import stats  # local import: only needed for the F-distribution tail

    if df_den <= 0:
        # Not enough observations left for the unrestricted model.
        return 0.0, 1.0

    y = lagged_df[target]

    def _ssr(columns: List[str]) -> float:
        X = lagged_df[columns]
        residuals = y - LinearRegression().fit(X, y).predict(X)
        return float(np.sum(residuals ** 2))

    ssr_restricted = _ssr(own_lag_cols)
    ssr_full = _ssr(own_lag_cols + other_lag_cols)

    if ssr_restricted <= 0:
        # The target's own history already fits perfectly; nothing left to explain.
        return 0.0, 1.0
    if ssr_full <= 0:
        # The added lags turn an imperfect fit into a perfect one.
        return float("inf"), 0.0

    # The models are nested, so the improvement cannot be negative except
    # through floating-point noise.
    improvement = max(ssr_restricted - ssr_full, 0.0)
    f_statistic = (improvement / df_num) / (ssr_full / df_den)
    p_value = float(stats.f.sf(f_statistic, df_num, df_den))
    return float(f_statistic), p_value


def granger_causality_test(
    df: pd.DataFrame,
    component_id: str,
    outcome_var: str = "perturbation",
    max_lag: int = 2,
    significance_level: float = 0.05,
) -> Dict[str, float]:
    """
    Implements a simplified Granger causality test to assess if a component
    'Granger-causes' the outcome, and whether the reverse also holds.

    Each direction compares two OLS models with an F-test: the target regressed
    on its own lags, versus the target regressed on its own lags plus the lags
    of the other variable.

    Note: Rows are treated as consecutive time steps in their existing order.
    If the data doesn't have a clear time dimension, the results should be
    interpreted with caution.

    Args:
        df: DataFrame with binary component features and the outcome column
        component_id: ID of the component to analyze (including 'entity_' or 'relation_' prefix)
        outcome_var: Name of the outcome variable (default: 'perturbation')
        max_lag: Maximum number of lags to include in the model (must be >= 1)
        significance_level: p-value threshold used to decide the causal
            direction (default: 0.05)

    Returns:
        Dictionary with 'f_statistic' and 'p_value' (component -> outcome),
        'f_statistic_reverse' and 'p_value_reverse' (outcome -> component), and
        'causal_direction', a string that is one of 'none',
        'component -> outcome', 'outcome -> component' or 'bidirectional'.
        When the test cannot be run, a neutral result (F = 0, p = 1,
        direction 'none') is returned.
    """
    null_result = {
        "f_statistic": 0.0,
        "p_value": 1.0,
        "f_statistic_reverse": 0.0,
        "p_value_reverse": 1.0,
        "causal_direction": "none",
    }

    if component_id not in df.columns:
        logger.error(f"Component {component_id} not found in DataFrame")
        return null_result
    if outcome_var not in df.columns:
        logger.error(f"Outcome {outcome_var} not found in DataFrame")
        return null_result
    if max_lag < 1:
        logger.warning(f"max_lag must be at least 1 for Granger causality test, got {max_lag}")
        return null_result

    # Check if there's enough data points
    if len(df) <= max_lag + 1:
        logger.warning(f"Not enough data points for Granger causality test with max_lag={max_lag}")
        return null_result

    # Check if there's enough variation in the variables (nunique is robust to
    # the rounding noise that can make std() of a constant column non-zero).
    if df[component_id].nunique() < 2 or df[outcome_var].nunique() < 2:
        logger.warning("No variation in component or outcome, cannot test Granger causality")
        return null_result

    # Build the lagged frame from the two relevant series only, so that the
    # dropna() below is not affected by missing values in unrelated columns.
    lagged_df = pd.DataFrame({
        component_id: df[component_id],
        outcome_var: df[outcome_var],
    })
    component_lags = []
    outcome_lags = []
    for lag in range(1, max_lag + 1):
        component_lag_col = f"{component_id}_lag{lag}"
        outcome_lag_col = f"{outcome_var}_lag{lag}"
        lagged_df[component_lag_col] = df[component_id].shift(lag)
        lagged_df[outcome_lag_col] = df[outcome_var].shift(lag)
        component_lags.append(component_lag_col)
        outcome_lags.append(outcome_lag_col)

    # Drop rows with NaN values (the first max_lag rows have incomplete lags)
    lagged_df = lagged_df.dropna()

    # Degrees of freedom: max_lag added regressors; the unrestricted model
    # estimates 2 * max_lag slopes plus an intercept.
    df1 = max_lag
    df2 = len(lagged_df) - 2 * max_lag - 1

    # Forward direction: do past component values help predict the outcome?
    f_statistic, p_value = _granger_f_test(
        lagged_df, outcome_var, outcome_lags, component_lags, df1, df2
    )
    # Reverse direction: do past outcomes help predict the component?
    f_statistic_reverse, p_value_reverse = _granger_f_test(
        lagged_df, component_id, component_lags, outcome_lags, df1, df2
    )

    # Determine causality direction
    forward = p_value < significance_level
    reverse = p_value_reverse < significance_level
    if forward and reverse:
        causal_direction = "bidirectional"
    elif forward:
        causal_direction = "component -> outcome"
    elif reverse:
        causal_direction = "outcome -> component"
    else:
        causal_direction = "none"

    return {
        "f_statistic": f_statistic,
        "p_value": p_value,
        "f_statistic_reverse": f_statistic_reverse,
        "p_value_reverse": p_value_reverse,
        "causal_direction": causal_direction,
    }
def compute_causal_effect_strength(
    df: pd.DataFrame,
    control_group: Optional[List[str]] = None,
    outcome_var: str = "perturbation"
) -> Dict[str, float]:
    """
    Computes the strength of causal effects for all components.

    Every column whose name starts with 'entity_' or 'relation_' is treated as
    a component, and its effect strength is the 'ate' value returned by
    calculate_average_treatment_effect.

    Args:
        df: DataFrame with binary component features and the outcome column
        control_group: List of components to use as control variables. When
            None, the ATE routine's default applies. The component being
            analyzed is always removed from its own control set.
        outcome_var: Name of the outcome variable (default: 'perturbation')

    Returns:
        Dictionary mapping component IDs to their causal effect strengths.
        Components whose estimation raises an error are mapped to 0.0. An
        empty dictionary is returned when there are no component columns.
    """
    # Get all component columns (non-string column labels cannot be components)
    component_cols = [
        col for col in df.columns
        if isinstance(col, str) and col.startswith(("entity_", "relation_"))
    ]
    if not component_cols:
        logger.error("No component features found in DataFrame")
        return {}

    # Calculate ATE for each component
    effect_strengths = {}
    for component_id in component_cols:
        # A component must not act as its own control: the duplicated column
        # would be perfectly collinear with the treatment and distort the ATE.
        if control_group is None:
            controls = None
        else:
            controls = [col for col in control_group if col != component_id]

        try:
            ate_results = calculate_average_treatment_effect(
                df,
                component_id,
                outcome_var=outcome_var,
                control_vars=controls
            )
            effect_strengths[component_id] = ate_results["ate"]
        except Exception as e:
            # Best effort: one failing component must not abort the whole scan.
            logger.warning(f"Error calculating ATE for {component_id}: {e}")
            effect_strengths[component_id] = 0.0
    return effect_strengths
| # Note: create_mock_perturbation_scores and list_available_components | |
| # moved to utils.shared_utils to avoid duplication | |
def main():
    """Command-line entry point: build the component DataFrame, rank components by ATE, and optionally save the ranking."""
    cli = argparse.ArgumentParser(
        description='Analyze causal relationships between components and perturbation scores'
    )
    cli.add_argument('--input', '-i', required=True, help='Path to the knowledge graph JSON file')
    cli.add_argument('--output', '-o', help='Path to save the output analysis (CSV format)')
    args = cli.parse_args()

    print("Loading knowledge graph")

    # Build the component/perturbation table from the input file
    df = create_component_influence_dataframe(args.input)
    if df is None or df.empty:
        logger.error("Failed to create or empty DataFrame. Cannot proceed with analysis.")
        return

    # Summary of the table that was built
    print("\nDataFrame info:")
    print(f"Rows: {len(df)}")
    entity_count = len([name for name in df.columns if name.startswith("entity_")])
    relation_count = len([name for name in df.columns if name.startswith("relation_")])
    print(f"Entity features: {entity_count}")
    print(f"Relation features: {relation_count}")

    # Warn when the outcome has no spread; otherwise describe its distribution
    scores = df['perturbation']
    if scores.std() == 0:
        logger.warning("All perturbation scores are identical. This might lead to uninformative results.")
        print("\nWARNING: All perturbation scores are identical (value: %.2f). Results may not be meaningful." % scores.iloc[0])
    else:
        print("\nPerturbation score distribution:")
        print(f"Min: {scores.min():.2f}, Max: {scores.max():.2f}")
        print(f"Mean: {scores.mean():.2f}, Std: {scores.std():.2f}")

    # Estimate one effect strength per component
    print("\nComputing causal effect strengths...")
    effect_strengths = compute_causal_effect_strength(df)
    print(f"Found {len(effect_strengths)} components with causal effects")

    # Rank by absolute effect, largest first
    ranked = sorted(effect_strengths.items(), key=lambda pair: abs(pair[1]), reverse=True)

    print("\nTop 10 Components by Causal Effect Strength:")
    print("=" * 50)
    print(f"{'Rank':<5}{'Component':<30}{'Effect Strength':<15}")
    print("-" * 50)
    for rank, (component, strength) in zip(range(1, 11), ranked):
        print(f"{rank:<5}{component:<30}{strength:.6f}")

    if args.output:
        results_df = pd.DataFrame({
            'Component': [pair[0] for pair in ranked],
            'Effect_Strength': [pair[1] for pair in ranked]
        })

        def _save(path, announcement):
            # Announce, then attempt the write; a failure is reported, not raised.
            print(announcement)
            try:
                results_df.to_csv(path, index=False)
                print(f"Successfully saved results to: {path}")
            except Exception as e:
                print(f"Error saving to {path}: {e!s}")

        # First the path requested on the command line...
        _save(args.output, f"\nSaving results to: {args.output}")

        # ...then a second copy next to this module file
        default_output = os.path.join(os.path.dirname(__file__), 'causal_component_effects.csv')
        _save(default_output, f"Also saving results to: {default_output}")

    print("\nAnalysis complete.")


if __name__ == "__main__":
    main()