File size: 19,713 Bytes
c2ea5ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
#!/usr/bin/env python3
"""
DoWhy Causal Component Analysis

This script implements causal inference methods using the DoWhy library to analyze 
the causal relationship between knowledge graph components and perturbation scores.
"""

import os
import sys
import pandas as pd
import numpy as np
import argparse
import logging
import json
from typing import Dict, List, Optional, Tuple, Set
from collections import defaultdict

# Import DoWhy
import dowhy
from dowhy import CausalModel

# Import from utils directory
from .utils.dataframe_builder import create_component_influence_dataframe
# Import shared utilities
from .utils.shared_utils import create_mock_perturbation_scores, list_available_components

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)

# Configure the root logger at CRITICAL.
# NOTE(review): when this call takes effect (i.e. the root logger has no
# handlers yet), this module's own info/warning/error records are filtered
# out as well, because `logger` inherits the root level -- confirm intended.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.CRITICAL,
)

# DoWhy loggers whose INFO/DEBUG chatter is capped by pinning them to WARNING.
for noisy_logger in (
    "dowhy",
    "dowhy.causal_estimator",
    "dowhy.causal_model",
    "dowhy.causal_refuter",
    "dowhy.do_sampler",
    "dowhy.identifier",
    "dowhy.propensity_score",
    "dowhy.utils",
    "dowhy.causal_refuter.add_unobserved_common_cause",
):
    logging.getLogger(noisy_logger).setLevel(logging.WARNING)

# Note: create_mock_perturbation_scores and list_available_components 
# moved to utils.shared_utils to avoid duplication

def generate_simple_causal_graph(df: pd.DataFrame, treatment: str, outcome: str) -> str:
    """
    Generate a simple causal graph in a format compatible with DoWhy.

    Every column whose name starts with ``entity_`` or ``relation_`` (other
    than the treatment and the outcome themselves) is treated as a covariate.
    A covariate whose absolute Pearson correlation with the treatment is at
    least 0.7 is drawn as a confounder (edges into both treatment and
    outcome); every other covariate only gets an edge into the outcome.

    Args:
        df: DataFrame with features
        treatment: Treatment variable name
        outcome: Outcome variable name

    Returns:
        String representation of the causal graph in DoWhy format (DOT digraph)
    """
    # Candidate covariates. The outcome is excluded as well as the treatment:
    # otherwise an outcome column carrying one of the component prefixes would
    # yield a self-loop ("outcome" -> "outcome") and break the DAG assumption.
    component_cols = [
        col for col in df.columns
        if col.startswith(('entity_', 'relation_')) and col not in (treatment, outcome)
    ]

    # Correlation threshold to identify potential confounders
    confounder_threshold = 0.7
    potential_confounders = []

    treatment_values = df[treatment]
    # A constant treatment has undefined correlation with everything, so no
    # confounder can be detected; this check is loop-invariant and done once.
    if treatment_values.std() != 0:
        for component in component_cols:
            # Skip if no variance (would result in correlation NaN)
            if df[component].std() == 0:
                continue

            correlation = df[component].corr(treatment_values)
            # A NaN correlation compares False here and is therefore ignored.
            if abs(correlation) >= confounder_threshold:
                potential_confounders.append(component)

    confounder_set = set(potential_confounders)

    # Collect edges in DOT syntax: treatment -> outcome first ...
    edges = [f'"{treatment}" -> "{outcome}";']

    # ... then confounders, which affect both treatment and outcome ...
    for confounder in potential_confounders:
        edges.append(f'"{confounder}" -> "{treatment}";')
        edges.append(f'"{confounder}" -> "{outcome}";')

    # ... and finally the remaining components as plain causes of the outcome.
    edges.extend(
        f'"{component}" -> "{outcome}";'
        for component in component_cols
        if component not in confounder_set
    )

    return "digraph {" + "".join(edges) + "}"

def _find_interaction_candidates(df: pd.DataFrame, treatment: str, outcome_var: str) -> List[str]:
    """Return components that co-occur with the treatment more than 1.5x the chance expectation."""
    # Only check if the treatment appears in the data
    if not (df[treatment].sum() > 0):
        return []

    # Compare against zero instead of applying bitwise `&` to the raw columns,
    # so 0/1 features stored as floats do not raise a TypeError.
    treated = df[treatment] != 0
    treatment_rate = df[treatment].mean()
    n_rows = len(df)

    candidates = []
    for component in df.columns:
        if not component.startswith(('entity_', 'relation_')):
            continue
        if component in (treatment, outcome_var):
            continue
        # Skip components with no occurrences
        if df[component].sum() == 0:
            continue

        # Co-occurrence expected if treatment and component were independent
        expected_cooccurrence = treatment_rate * df[component].mean() * n_rows
        actual_cooccurrence = (treated & (df[component] != 0)).sum()

        # This is a simplistic heuristic, not a significance test
        if actual_cooccurrence > 1.5 * expected_cooccurrence:
            candidates.append(component)
    return candidates


def _estimate_interaction_effects(
    df: pd.DataFrame,
    treatment: str,
    outcome_var: str,
    components: List[str]
) -> List[Dict]:
    """Fit outcome ~ treatment + component + treatment*component and return each product-term coefficient."""
    effects = []
    for component in components:
        try:
            from sklearn.linear_model import LinearRegression

            # Build the design matrix on the side: the caller's DataFrame is
            # never given a temporary column, so nothing can leak into it if
            # the fit fails.
            interaction_col = f"{treatment}_x_{component}"
            design = pd.DataFrame({
                treatment: df[treatment],
                component: df[component],
                interaction_col: df[treatment] * df[component],
            })

            regression = LinearRegression()
            regression.fit(design, df[outcome_var])

            effects.append({
                "component": component,
                # Index 2 is the interaction term (third column of `design`)
                "interaction_coefficient": float(regression.coef_[2])
            })
        except Exception as e:
            # Best effort: one failed interaction must not abort the analysis
            logger.warning(f"Error analyzing interaction with {component}: {str(e)}")
    return effects


def _run_refutations(model, identified_estimand, estimate) -> List[Dict]:
    """Run the three DoWhy refuters, collecting the string form of each one that succeeds."""
    # (DoWhy method name, label stored in the result, prefix of the failure message)
    refuters = (
        ("random_common_cause", "random_common_cause", "Random common cause"),
        ("placebo_treatment_refuter", "placebo_treatment", "Placebo treatment"),
        ("data_subset_refuter", "data_subset", "Data subset"),
    )

    refutation_results = []
    for method_name, label, description in refuters:
        try:
            refutation = model.refute_estimate(
                identified_estimand,
                estimate,
                method_name=method_name
            )
            refutation_results.append({
                "method": label,
                "refutation_result": str(refutation)
            })
        except Exception as e:
            # A failing refuter is logged and skipped; the others still run
            logger.warning(f"{description} refutation failed: {str(e)}")
    return refutation_results


def run_dowhy_analysis(
    df: pd.DataFrame,
    treatment_component: str,
    outcome_var: str = "perturbation",
    proceed_when_unidentifiable: bool = True
) -> Dict:
    """
    Run causal analysis using DoWhy for a single treatment component.

    The effect is estimated with ``backdoor.linear_regression`` on the graph
    produced by ``generate_simple_causal_graph``. Components that co-occur
    with the treatment unusually often are additionally checked for an
    interaction effect. ``df`` is not modified.

    Args:
        df: DataFrame with binary component features and outcome variable
        treatment_component: Name of the component to analyze
        outcome_var: Name of the outcome variable
        proceed_when_unidentifiable: Whether to proceed when effect is unidentifiable

    Returns:
        Dictionary with causal analysis results. It always contains
        ``"component"``. On success it also holds ``"identified_estimand"``,
        ``"effect_estimate"``, ``"refutation_results"`` and, when any were
        found, ``"interaction_effects"``. On failure it holds ``"error"``
        (plus ``"identified_estimand"`` if identification succeeded).
        Exceptions raised by DoWhy are caught and reported through ``"error"``.
    """
    # Ensure the treatment_component is in the expected format
    if treatment_component not in df.columns:
        logger.error(f"Treatment component {treatment_component} not found in DataFrame")
        return {"component": treatment_component, "error": "Component not found"}
    treatment = treatment_component

    # Components that might interact with the treatment (co-occurrence heuristic)
    interaction_components = _find_interaction_candidates(df, treatment, outcome_var)

    # Generate a simple causal graph
    graph = generate_simple_causal_graph(df, treatment, outcome_var)

    try:
        # Create the causal model
        model = CausalModel(
            data=df,
            treatment=treatment,
            outcome=outcome_var,
            graph=graph,
            proceed_when_unidentifiable=proceed_when_unidentifiable
        )

        # Log the graph (for debugging)
        logger.info(f"Causal graph for {treatment}: {graph}")

        # Identify the causal effect
        identified_estimand = model.identify_effect(proceed_when_unidentifiable=proceed_when_unidentifiable)
        logger.info(f"Identified estimand for {treatment}")

        # If there's no variance in the outcome, we can't estimate effect
        if df[outcome_var].std() == 0:
            logger.warning(f"No variance in outcome variable {outcome_var}, skipping estimation")
            return {
                # Report the name exactly as every other branch does, so
                # callers can match results by component name.
                "component": treatment,
                "identified_estimand": str(identified_estimand),
                "error": "No variance in outcome variable"
            }

        try:
            # Estimate the causal effect
            estimate = model.estimate_effect(
                identified_estimand,
                method_name="backdoor.linear_regression",
                target_units="ate",
                test_significance=None
            )
            logger.info(f"Estimated causal effect for {treatment}: {estimate.value}")

            # Check for interaction effects with the candidate components
            interaction_effects = _estimate_interaction_effects(
                df, treatment, outcome_var, interaction_components
            )

            # Refute the results
            refutation_results = _run_refutations(model, identified_estimand, estimate)

            result = {
                "component": treatment,
                "identified_estimand": str(identified_estimand),
                "effect_estimate": float(estimate.value),
                "refutation_results": refutation_results
            }

            # Add interaction effects if found
            if interaction_effects:
                result["interaction_effects"] = interaction_effects

            return result

        except Exception as e:
            logger.error(f"Error estimating effect for {treatment}: {str(e)}")
            return {
                "component": treatment,
                "identified_estimand": str(identified_estimand),
                "error": f"Estimation error: {str(e)}"
            }

    except Exception as e:
        logger.error(f"Error in causal analysis for {treatment}: {str(e)}")
        return {
            "component": treatment,
            "error": str(e)
        }

def analyze_components_with_dowhy(
    df: pd.DataFrame,
    components_to_analyze: List[str]
) -> List[Dict]:
    """
    Run the single-component DoWhy analysis for each requested component.

    Progress and a short summary per component are printed to stdout. After
    all components are processed, every error-free result whose component was
    reported as an interaction partner by another analysis gains an
    ``"interacts_with"`` list naming those analyses.

    Args:
        df: DataFrame with binary component features and outcome variable
        components_to_analyze: List of component names to analyze

    Returns:
        List of dictionaries with causal analysis results, one per component,
        in the order given
    """
    # partner component -> [{"component": <treatment>, "interaction_coefficient": <coef>}, ...]
    partners = defaultdict(list)
    results = []

    # Pass 1: analyze every component on its own and record reported interactions
    for component in components_to_analyze:
        print(f"\nAnalyzing causal effect of component: {component}")
        analysis = run_dowhy_analysis(df, component)
        results.append(analysis)

        if "error" in analysis:
            print(f"  Error: {analysis['error']}")
            continue

        print(f"  Estimated causal effect: {analysis.get('effect_estimate', 'N/A')}")

        for found in analysis.get("interaction_effects", ()):
            partner_name = found["component"]
            coefficient = found["interaction_coefficient"]
            partners[partner_name].append({
                "component": component,
                "interaction_coefficient": coefficient
            })
            print(f"  Interaction with {partner_name}: {coefficient}")

    # Pass 2: attach the reverse view to each successful result
    for analysis in results:
        name = analysis.get("component")
        if "error" in analysis or not name:
            continue

        # .get() so that a lookup never creates an empty entry in the defaultdict
        linked = partners.get(name)
        if linked:
            analysis["interacts_with"] = linked

    return results

def main():
    """
    Main function to run the DoWhy causal component analysis.

    Parses the command-line options, builds the component DataFrame from the
    JSON file, optionally injects mock perturbation scores (``--test``), runs
    the causal analysis for the chosen treatments and writes the results as
    JSON next to this script. Fatal problems are logged and also printed to
    stderr before returning early, so they stay visible regardless of the
    configured log level.
    """
    # Set up argument parser
    parser = argparse.ArgumentParser(description='DoWhy Causal Component Analysis')
    parser.add_argument('--test', action='store_true', help='Enable test mode with mock perturbation scores')
    parser.add_argument('--components', nargs='+', help='Component names to test in test mode')
    parser.add_argument('--treatments', nargs='+', help='Component names to treat as treatments for causal analysis')
    parser.add_argument('--list-components', action='store_true', help='List available components and exit')
    parser.add_argument('--base-score', type=float, default=1.0, help='Base perturbation score (default: 1.0)')
    parser.add_argument('--treatment-score', type=float, default=0.2, help='Score for test components (default: 0.2)')
    parser.add_argument('--json-file', type=str, help='Path to JSON file (default: example.json)')
    parser.add_argument('--top-k', type=int, default=5, help='Number of top components to analyze (default: 5)')
    args = parser.parse_args()

    # User-specified file, or example.json next to this script
    json_file = args.json_file or os.path.join(os.path.dirname(__file__), 'example.json')

    # Create DataFrame using the function from create_component_influence_dataframe.py
    df = create_component_influence_dataframe(json_file)

    if df is None or df.empty:
        message = "Failed to create or empty DataFrame. Cannot proceed with analysis."
        logger.error(message)
        print(f"\nError: {message}", file=sys.stderr)
        return

    # List components if requested
    if args.list_components:
        components = list_available_components(df)
        print("\nAvailable components:")
        for i, comp in enumerate(components, 1):
            print(f"{i}. {comp}")
        return

    # Create mock perturbation scores if in test mode
    if args.test:
        if args.components:
            test_components = args.components
        else:
            logger.warning("No components specified for test mode. Using random components.")
            # Select random components if none specified
            all_components = list_available_components(df)
            if len(all_components) == 0:
                message = "No components found in DataFrame. Cannot create mock scores."
                logger.error(message)
                print(f"\nError: {message}", file=sys.stderr)
                return
            test_components = np.random.choice(
                all_components,
                size=min(2, len(all_components)),
                replace=False
            ).tolist()

        print(f"\nTest mode enabled. Using components: {', '.join(test_components)}")
        print(f"Setting base score: {args.base_score}, treatment score: {args.treatment_score}")

        # Create mock perturbation scores
        df = create_mock_perturbation_scores(
            df,
            test_components,
            base_score=args.base_score,
            treatment_score=args.treatment_score
        )

    # The analysis needs the outcome column; fail with a message, not a KeyError
    if 'perturbation' not in df.columns:
        message = "DataFrame has no 'perturbation' column. Cannot proceed with analysis."
        logger.error(message)
        print(f"\nError: {message}", file=sys.stderr)
        return

    # Print basic DataFrame info. Feature columns are recognised by the
    # entity_/relation_ prefixes the analysis functions in this file use; the
    # older comp_ prefix is still accepted.
    feature_prefixes = ('comp_', 'entity_', 'relation_')
    feature_cols = [col for col in df.columns if col.startswith(feature_prefixes)]
    other_cols = [col for col in df.columns if not col.startswith(feature_prefixes)]
    print("\nDataFrame info:")
    print(f"Rows: {len(df)}")
    print(f"Features: {len(feature_cols)}")
    print(f"Columns: {', '.join(other_cols)}")

    # Check if we have any variance in perturbation scores
    perturbation = df['perturbation']
    if perturbation.std() == 0:
        print("\nWARNING: All perturbation scores are identical (value: %.2f)." % perturbation.iloc[0])
        print("         This will limit the effectiveness of causal analysis.")
        print("         Consider using synthetic data with varied perturbation scores for better results.\n")
    else:
        print("\nPerturbation score statistics:")
        print(f"Min: {perturbation.min():.2f}")
        print(f"Max: {perturbation.max():.2f}")
        print(f"Mean: {perturbation.mean():.2f}")
        print(f"Std: {perturbation.std():.2f}")

    # Determine components to analyze
    if args.treatments:
        components_to_analyze = args.treatments
    else:
        # Default to top-k components
        components_to_analyze = list_available_components(df)[:args.top_k]

    print(f"\nAnalyzing {len(components_to_analyze)} components as treatments: {', '.join(components_to_analyze)}")

    # Run DoWhy causal analysis for each treatment component
    results = analyze_components_with_dowhy(df, components_to_analyze)

    # Save results to JSON file
    output_filename = 'test_dowhy_causal_effects.json' if args.test else 'dowhy_causal_effects.json'
    output_path = os.path.join(os.path.dirname(__file__), output_filename)
    try:
        # Serialize first so a non-serializable value cannot leave a
        # half-written file behind.
        payload = json.dumps({
            "metadata": {
                "json_file": json_file,
                "test_mode": args.test,
                "components_analyzed": components_to_analyze,
            },
            "results": results
        }, indent=2)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        logger.info(f"Causal analysis results saved to {output_path}")
        print(f"\nCausal analysis complete. Results saved to {output_path}")
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"Error saving results to {output_path}: {str(e)}")
        print(f"\nError saving results: {str(e)}")

if __name__ == "__main__":
    main()