Spaces:
Running
Running
| """ | |
| decision tree component for selecting causal inference methods | |
| this module implements the decision tree logic to select the most appropriate | |
| causal inference method based on dataset characteristics and available variables | |
| """ | |
| import logging | |
| from typing import Dict, List, Any, Optional | |
| import pandas as pd | |
| # define method names | |
| BACKDOOR_ADJUSTMENT = "backdoor_adjustment" | |
| LINEAR_REGRESSION = "linear_regression" | |
| DIFF_IN_MEANS = "diff_in_means" | |
| DIFF_IN_DIFF = "difference_in_differences" | |
| REGRESSION_DISCONTINUITY = "regression_discontinuity_design" | |
| PROPENSITY_SCORE_MATCHING = "propensity_score_matching" | |
| INSTRUMENTAL_VARIABLE = "instrumental_variable" | |
| CORRELATION_ANALYSIS = "correlation_analysis" | |
| PROPENSITY_SCORE_WEIGHTING = "propensity_score_weighting" | |
| GENERALIZED_PROPENSITY_SCORE = "generalized_propensity_score" | |
| FRONTDOOR_ADJUSTMENT = "frontdoor_adjustment" | |
| logger = logging.getLogger(__name__) | |
| # method assumptions mapping | |
| METHOD_ASSUMPTIONS = { | |
| BACKDOOR_ADJUSTMENT: [ | |
| "no unmeasured confounders (conditional ignorability given covariates)", | |
| "correct model specification for outcome conditional on treatment and covariates", | |
| "positivity/overlap (for all covariate values, units could potentially receive either treatment level)" | |
| ], | |
| LINEAR_REGRESSION: [ | |
| "linear relationship between treatment, covariates, and outcome", | |
| "no unmeasured confounders (if observational)", | |
| "correct model specification", | |
| "homoscedasticity of errors", | |
| "normally distributed errors (for inference)" | |
| ], | |
| DIFF_IN_MEANS: [ | |
| "treatment is randomly assigned (or as-if random)", | |
| "no spillover effects", | |
| "stable unit treatment value assumption (SUTVA)" | |
| ], | |
| DIFF_IN_DIFF: [ | |
| "parallel trends between treatment and control groups before treatment", | |
| "no spillover effects between groups", | |
| "no anticipation effects before treatment", | |
| "stable composition of treatment and control groups", | |
| "treatment timing is exogenous" | |
| ], | |
| REGRESSION_DISCONTINUITY: [ | |
| "units cannot precisely manipulate the running variable around the cutoff", | |
| "continuity of conditional expectation functions of potential outcomes at the cutoff", | |
| "no other changes occurring precisely at the cutoff" | |
| ], | |
| PROPENSITY_SCORE_MATCHING: [ | |
| "no unmeasured confounders (conditional ignorability)", | |
| "sufficient overlap (common support) between treatment and control groups", | |
| "correct propensity score model specification" | |
| ], | |
| INSTRUMENTAL_VARIABLE: [ | |
| "instrument is correlated with treatment (relevance)", | |
| "instrument affects outcome only through treatment (exclusion restriction)", | |
| "instrument is independent of unmeasured confounders (exogeneity/independence)" | |
| ], | |
| CORRELATION_ANALYSIS: [ | |
| "data represents a sample from the population of interest", | |
| "variables are measured appropriately" | |
| ], | |
| PROPENSITY_SCORE_WEIGHTING: [ | |
| "no unmeasured confounders (conditional ignorability)", | |
| "sufficient overlap (common support) between treatment and control groups", | |
| "correct propensity score model specification", | |
| "weights correctly specified (e.g., ATE, ATT)" | |
| ], | |
| GENERALIZED_PROPENSITY_SCORE: [ | |
| "conditional mean independence", | |
| "positivity/common support for GPS", | |
| "correct specification of the GPS model", | |
| "correct specification of the outcome model", | |
| "no unmeasured confounders affecting both treatment and outcome, given X", | |
| "treatment variable is continuous" | |
| ], | |
| FRONTDOOR_ADJUSTMENT: [ | |
| "mediator is affected by treatment and affects outcome", | |
| "mediator is not affected by any confounders of the treatment-outcome relationship" | |
| ] | |
| } | |
| def select_method(dataset_properties: Dict[str, Any], excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]: | |
| excluded_methods = set(excluded_methods or []) | |
| logger.info(f"Excluded methods: {sorted(excluded_methods)}") | |
| treatment = dataset_properties.get("treatment_variable") | |
| outcome = dataset_properties.get("outcome_variable") | |
| if not treatment or not outcome: | |
| raise ValueError("Both treatment and outcome variables must be specified") | |
| instrument_var = dataset_properties.get("instrument_variable") | |
| running_var = dataset_properties.get("running_variable") | |
| cutoff_val = dataset_properties.get("cutoff_value") | |
| time_var = dataset_properties.get("time_variable") | |
| is_rct = dataset_properties.get("is_rct", False) | |
| has_temporal = dataset_properties.get("has_temporal_structure", False) | |
| frontdoor = dataset_properties.get("frontdoor_criterion", False) | |
| covariate_overlap_result = dataset_properties.get("covariate_overlap_score") | |
| covariates = dataset_properties.get("covariates", []) | |
| treatment_variable_type = dataset_properties.get("treatment_variable_type", "binary") | |
| # Helpers to collect candidates | |
| candidates = [] # list of (method, priority_index) | |
| justifications: Dict[str, str] = {} | |
| assumptions: Dict[str, List[str]] = {} | |
| def add(method: str, justification: str, prio_order: List[str]): | |
| if method in justifications: # already added | |
| return | |
| justifications[method] = justification | |
| assumptions[method] = METHOD_ASSUMPTIONS[method] | |
| # priority index from provided order (fallback large if not present) | |
| try: | |
| idx = prio_order.index(method) | |
| except ValueError: | |
| idx = 10**6 | |
| candidates.append((method, idx)) | |
| # ----- Build candidate set (no returns here) ----- | |
| # RCT branch | |
| if is_rct: | |
| logger.info("Dataset is from a randomized controlled trial (RCT)") | |
| rct_priority = [INSTRUMENTAL_VARIABLE, LINEAR_REGRESSION, DIFF_IN_MEANS] | |
| if instrument_var and instrument_var != treatment: | |
| add(INSTRUMENTAL_VARIABLE, | |
| f"RCT encouragement: instrument '{instrument_var}' differs from treatment '{treatment}'.", | |
| rct_priority) | |
| if covariates: | |
| add(LINEAR_REGRESSION, | |
| "RCT with covariates—use OLS for precision.", | |
| rct_priority) | |
| else: | |
| add(DIFF_IN_MEANS, | |
| "Pure RCT without covariates—difference-in-means.", | |
| rct_priority) | |
| # Observational branch | |
| obs_priority_binary = [ | |
| INSTRUMENTAL_VARIABLE, | |
| PROPENSITY_SCORE_MATCHING, | |
| PROPENSITY_SCORE_WEIGHTING, | |
| FRONTDOOR_ADJUSTMENT, | |
| LINEAR_REGRESSION, | |
| ] | |
| obs_priority_nonbinary = [ | |
| INSTRUMENTAL_VARIABLE, | |
| FRONTDOOR_ADJUSTMENT, | |
| LINEAR_REGRESSION, | |
| ] | |
| # Common early structural signals first (still only add as candidates) | |
| if has_temporal and time_var: | |
| add(DIFF_IN_DIFF, | |
| f"Temporal structure via '{time_var}'—consider Difference-in-Differences (assumes parallel trends).", | |
| [DIFF_IN_DIFF]) # highest among itself | |
| if running_var and cutoff_val is not None: | |
| add(REGRESSION_DISCONTINUITY, | |
| f"Running variable '{running_var}' with cutoff {cutoff_val}—consider RDD.", | |
| [REGRESSION_DISCONTINUITY]) | |
| # Binary vs non-binary pathways | |
| if treatment_variable_type == "binary": | |
| if instrument_var: | |
| add(INSTRUMENTAL_VARIABLE, | |
| f"Instrumental variable '{instrument_var}' available.", | |
| obs_priority_binary) | |
| # Propensity score methods only if covariates exist | |
| if covariates: | |
| if covariate_overlap_result is not None: | |
| ps_method = (PROPENSITY_SCORE_WEIGHTING | |
| if covariate_overlap_result < 0.1 | |
| else PROPENSITY_SCORE_MATCHING) | |
| else: | |
| ps_method = PROPENSITY_SCORE_MATCHING | |
| add(ps_method, | |
| "Covariates observed; PS method chosen based on overlap.", | |
| obs_priority_binary) | |
| if frontdoor: | |
| add(FRONTDOOR_ADJUSTMENT, | |
| "Front-door criterion satisfied.", | |
| obs_priority_binary) | |
| add(LINEAR_REGRESSION, | |
| "OLS as a fallback specification.", | |
| obs_priority_binary) | |
| else: | |
| logger.info(f"Non-binary treatment variable detected: {treatment_variable_type}") | |
| if instrument_var: | |
| add(INSTRUMENTAL_VARIABLE, | |
| f"Instrument '{instrument_var}' candidate for non-binary treatment.", | |
| obs_priority_nonbinary) | |
| if frontdoor: | |
| add(FRONTDOOR_ADJUSTMENT, | |
| "Front-door criterion satisfied.", | |
| obs_priority_nonbinary) | |
| add(LINEAR_REGRESSION, | |
| "Fallback for non-binary treatment without stronger identification.", | |
| obs_priority_nonbinary) | |
| # ----- Centralized exclusion handling ----- | |
| # Remove excluded | |
| filtered = [(m, p) for (m, p) in candidates if m not in excluded_methods] | |
| # If nothing survives, attempt a safe fallback not excluded | |
| if not filtered: | |
| logger.warning(f"All candidates excluded. Candidates were: {[m for m,_ in candidates]}. Excluded: {sorted(excluded_methods)}") | |
| fallback_order = [ | |
| LINEAR_REGRESSION, | |
| DIFF_IN_MEANS, | |
| PROPENSITY_SCORE_MATCHING, | |
| PROPENSITY_SCORE_WEIGHTING, | |
| DIFF_IN_DIFF, | |
| REGRESSION_DISCONTINUITY, | |
| INSTRUMENTAL_VARIABLE, | |
| FRONTDOOR_ADJUSTMENT, | |
| ] | |
| fallback = next((m for m in fallback_order if m in justifications and m not in excluded_methods), None) | |
| if not fallback: | |
| # truly nothing left; raise with context | |
| raise RuntimeError("No viable method remains after exclusions.") | |
| selected_method = fallback | |
| alternatives = [] | |
| justifications[selected_method] = justifications.get(selected_method, "Fallback after exclusions.") | |
| else: | |
| # Pick by smallest priority index, then stable by insertion | |
| filtered.sort(key=lambda x: x[1]) | |
| selected_method = filtered[0][0] | |
| alternatives = [m for (m, _) in filtered[1:] if m != selected_method] | |
| logger.info(f"Selected method: {selected_method}; alternatives: {alternatives}") | |
| return { | |
| "selected_method": selected_method, | |
| "method_justification": justifications[selected_method], | |
| "method_assumptions": assumptions[selected_method], | |
| "alternatives": alternatives, | |
| "excluded_methods": sorted(excluded_methods), | |
| } | |
| def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_description, original_query, excluded_methods=None): | |
| """ | |
| Wrapped function to select causal method based on dataset properties and query | |
| Args: | |
| dataset_analysis (Dict): results of dataset analysis | |
| variables (Dict): dictionary of variable names and types | |
| is_rct (bool): whether the dataset is from a randomized controlled trial | |
| llm (BaseChatModel): language model instance for generating prompts | |
| dataset_description (str): description of the dataset | |
| original_query (str): the original user query | |
| excluded_methods (List[str], optional): list of methods to exclude from selection | |
| """ | |
| logger.info("Running rule-based method selection") | |
| properties = {"treatment_variable": variables.get("treatment_variable"), "instrument_variable":variables.get("instrument_variable"), | |
| "covariates": variables.get("covariates", []), "outcome_variable": variables.get("outcome_variable"), | |
| "time_variable": variables.get("time_variable"), "running_variable": variables.get("running_variable"), | |
| "treatment_variable_type": variables.get("treatment_variable_type", "binary"), | |
| "has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False), | |
| "frontdoor_criterion": variables.get("frontdoor_criterion", False), | |
| "cutoff_value": variables.get("cutoff_value"), | |
| "covariate_overlap_score": variables.get("covariate_overlap_result", 0)} | |
| properties["is_rct"] = is_rct | |
| logger.info(f"Dataset properties for method selection: {properties}") | |
| return select_method(properties, excluded_methods) | |
| class DecisionTreeEngine: | |
| """ | |
| Engine for applying decision trees to select appropriate causal methods. | |
| This class wraps the functional decision tree implementation to provide | |
| an object-oriented interface for method selection. | |
| """ | |
| def __init__(self, verbose=False): | |
| self.verbose = verbose | |
| def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str], | |
| dataset_analysis: Dict[str, Any], query_details: Dict[str, Any]) -> Dict[str, Any]: | |
| """ | |
| Apply decision tree to select appropriate causal method. | |
| """ | |
| if self.verbose: | |
| print(f"Applying decision tree for treatment: {treatment}, outcome: {outcome}") | |
| print(f"Available covariates: {covariates}") | |
| treatment_variable_type = query_details.get("treatment_variable_type") | |
| covariate_overlap_result = query_details.get("covariate_overlap_result") | |
| info = {"treatment_variable": treatment, "outcome_variable": outcome, | |
| "covariates": covariates, "time_variable": query_details.get("time_variable"), | |
| "group_variable": query_details.get("group_variable"), | |
| "instrument_variable": query_details.get("instrument_variable"), | |
| "running_variable": query_details.get("running_variable"), | |
| "cutoff_value": query_details.get("cutoff_value"), | |
| "is_rct": query_details.get("is_rct", False), | |
| "has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False), | |
| "frontdoor_criterion": query_details.get("frontdoor_criterion", False), | |
| "covariate_overlap_score": covariate_overlap_result, | |
| "treatment_variable_type": treatment_variable_type} | |
| result = select_method(info) | |
| if self.verbose: | |
| print(f"Selected method: {result['selected_method']}") | |
| print(f"Justification: {result['method_justification']}") | |
| result["decision_path"] = self._get_decision_path(result["selected_method"]) | |
| return result | |
| def _get_decision_path(self, method): | |
| if method == "linear_regression": | |
| return ["Check if randomized experiment", "Data appears to be from a randomized experiment with covariates"] | |
| elif method == "propensity_score_matching": | |
| return ["Check if randomized experiment", "Data is observational", | |
| "Check for sufficient covariate overlap", "Sufficient overlap exists"] | |
| elif method == "propensity_score_weighting": | |
| return ["Check if randomized experiment", "Data is observational", | |
| "Check for sufficient covariate overlap", "Low overlap—weighting preferred"] | |
| elif method == "backdoor_adjustment": | |
| return ["Check if randomized experiment", "Data is observational", | |
| "Check for sufficient covariate overlap", "Adjusting for covariates"] | |
| elif method == "instrumental_variable": | |
| return ["Check if randomized experiment", "Data is observational", | |
| "Check for instrumental variables", "Instrument is available"] | |
| elif method == "regression_discontinuity_design": | |
| return ["Check if randomized experiment", "Data is observational", | |
| "Check for discontinuity", "Discontinuity exists"] | |
| elif method == "difference_in_differences": | |
| return ["Check if randomized experiment", "Data is observational", | |
| "Check for temporal structure", "Panel data structure exists"] | |
| elif method == "frontdoor_adjustment": | |
| return ["Check if randomized experiment", "Data is observational", | |
| "Check front-door criterion", "Front-door path identified"] | |
| elif method == "diff_in_means": | |
| return ["Check if randomized experiment", "Pure RCT without covariates"] | |
| else: | |
| return ["Default method selection"] |