Spaces:
Running
Running
| """ | |
| Linear Regression Estimator for Causal Inference. | |
| Uses Ordinary Least Squares (OLS) to estimate the treatment effect, potentially | |
| adjusting for covariates. | |
| """ | |
| import pandas as pd | |
| import statsmodels.api as sm | |
| import statsmodels.formula.api as smf | |
| from typing import Dict, Any, List, Optional, Union | |
| import logging | |
| from langchain.chat_models.base import BaseChatModel | |
| import re | |
| import json | |
| from pydantic import BaseModel, ValidationError | |
| from langchain_core.messages import HumanMessage | |
| from langchain_core.exceptions import OutputParserException | |
| from auto_causal.models import LLMIdentifiedRelevantParams | |
| from auto_causal.prompts.regression_prompts import STATSMODELS_PARAMS_IDENTIFICATION_PROMPT_TEMPLATE | |
| from auto_causal.config import get_llm_client | |
| # Placeholder for potential future LLM assistance integration | |
| # from .llm_assist import interpret_lr_results, suggest_lr_covariates | |
| # Placeholder for potential future diagnostics integration | |
| # from .diagnostics import run_lr_diagnostics | |
| logger = logging.getLogger(__name__) | |
def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel) -> Optional[BaseModel]:
    """Query the LLM and return its answer parsed into ``pydantic_model``.

    The prompt is wrapped in a single human message and sent through the
    model's structured-output interface. Any failure is logged rather than
    raised, and the caller receives ``None`` in that case.

    Args:
        llm: Chat model used to answer the prompt.
        prompt: Full prompt text to send.
        pydantic_model: Schema class the response is parsed into.

    Returns:
        The parsed response, or ``None`` when the call or parsing failed.
    """
    try:
        human_turn = [HumanMessage(content=prompt)]
        return llm.with_structured_output(pydantic_model).invoke(human_turn)
    except (OutputParserException, ValidationError) as parse_err:
        # The model answered, but not in a shape the schema accepts.
        logger.error(f"LLM call failed parsing/validation for {pydantic_model.__name__}: {parse_err}")
    except Exception as unexpected_err:
        # Anything else (network, provider errors, ...) gets a traceback.
        logger.error(f"LLM call failed unexpectedly for {pydantic_model.__name__}: {unexpected_err}", exc_info=True)
    return None
| # Define module-level helper function | |
| def _clean_variable_name_for_patsy_local(name: str) -> str: | |
| if not isinstance(name, str): | |
| name = str(name) | |
| name = re.sub(r'[^a-zA-Z0-9_]', '_', name) | |
| if not re.match(r'^[a-zA-Z_]', name): | |
| name = 'var_' + name | |
| return name | |
def _build_treatment_term(
    treatment_col_name: str,
    is_categorical: bool,
    is_binary_encoded: bool,
    reference_level: Any,
) -> str:
    """Return the Patsy term used for the treatment's main effect."""
    if is_categorical and not is_binary_encoded:
        # Multi-level or binary categoricals that are not yet numeric.
        if reference_level:
            logger.info(f"Treating '{treatment_col_name}' as multi-level categorical with reference '{reference_level}'.")
            return f"C({treatment_col_name}, Treatment(reference='{reference_level}'))"
        logger.info(f"Treating '{treatment_col_name}' as categorical (Patsy will pick reference).")
        return f"C({treatment_col_name})"
    if is_binary_encoded:
        # Numeric 0/1 by now, but C() keeps the parameter naming categorical.
        logger.info(f"Treating label-encoded binary '{treatment_col_name}' as categorical for Patsy.")
        return f"C({treatment_col_name})"
    logger.info(f"Treating '{treatment_col_name}' as numeric for Patsy formula.")
    return treatment_col_name


def _build_interaction_term(
    df_analysis: pd.DataFrame,
    treatment_patsy_term: str,
    candidate_orig_name: str,
    column_mappings: Dict[str, Any],
) -> Optional[str]:
    """Return ``treatment:moderator`` for the formula, or None if the moderator cannot be resolved."""
    info = column_mappings.get(candidate_orig_name) or {}
    resolved_col = None
    if info.get('transformed_as') == 'one_hot_encoded':
        logger.warning(f"Interaction with one-hot encoded variable '{candidate_orig_name}' is complex. Currently skipping this interaction for Linear Regression.")
    elif info.get('new_column_name') and info['new_column_name'] in df_analysis.columns:
        resolved_col = info['new_column_name']
    elif candidate_orig_name in df_analysis.columns:
        # Not in the mappings, or the mapping kept the name (e.g. numeric).
        resolved_col = candidate_orig_name

    if not resolved_col:
        logger.warning(f"Could not resolve interaction variable candidate '{candidate_orig_name}' to a single usable column in processed data. Skipping interaction term.")
        return None

    moderator_term = resolved_col
    if df_analysis[resolved_col].dtype.name in ['object', 'category', 'bool'] or \
            info.get('original_dtype') in ['bool', 'category']:
        moderator_term = f"C({resolved_col})"
    interaction_term = f"{treatment_patsy_term}:{moderator_term}"
    logger.info(f"Adding interaction term to formula: {interaction_term}")
    return interaction_term


def _collect_param_stats(results: Any, conf_int_df: pd.DataFrame, position: int) -> Dict[str, Any]:
    """Gather estimate, p-value, 95% CI and standard error for the parameter at ``position``."""
    return {
        'estimate': results.params.iloc[position],
        'p_value': results.pvalues.iloc[position],
        'conf_int': conf_int_df.iloc[position].tolist(),
        'std_err': results.bse.iloc[position],
    }


def _extract_effects_via_llm(
    llm: Any,
    query_str: str,
    results: Any,
    treatment_patsy_term: str,
    treatment_col_name: str,
    is_multilevel_case: bool,
    reference_level_str: str,
) -> tuple:
    """Ask the LLM which fitted parameters are treatment effects.

    Returns ``(effects_by_key, all_parameters_identified)``. Any failure is
    logged and reported as ``({}, False)`` so the caller can fall back to
    deterministic parameter-name matching.
    """
    logger.info(f"Attempting LLM-based result extraction (informed by query: '{query_str[:50]}...').")
    try:
        param_names = results.params.index.tolist()
        # Computed once; the interval table does not change between parameters.
        conf_int_df = results.conf_int(alpha=0.05)
        indexed_param_names_str = "\n".join(
            f"{idx}: '{name}'" for idx, name in enumerate(param_names)
        )
        prompt_text = STATSMODELS_PARAMS_IDENTIFICATION_PROMPT_TEMPLATE.format(
            user_query=query_str,
            treatment_patsy_term=treatment_patsy_term,
            treatment_col_name=treatment_col_name,
            is_multilevel_case=is_multilevel_case,
            reference_level_for_prompt=reference_level_str,
            indexed_param_names_str=indexed_param_names_str,
            llm_response_schema_json=json.dumps(LLMIdentifiedRelevantParams.model_json_schema(), indent=2),
        )
        response = _call_llm_for_var(llm, prompt_text, LLMIdentifiedRelevantParams)
        if not (response and response.identified_params):
            logger.warning("LLM parameter identification did not yield usable parameters. Falling back to parameter-name matching.")
            return {}, False

        logger.info("LLM identified relevant parameters. Proceeding with programmatic extraction.")
        effects = {}
        for item in response.identified_params:
            param_idx = item.param_index
            if not 0 <= param_idx < len(param_names):
                logger.warning(f"LLM returned an invalid parameter index: {param_idx}. Skipping.")
                continue
            actual_param_name = param_names[param_idx]
            # The index is authoritative; the name is only a sanity check.
            if item.param_name != actual_param_name:
                logger.warning(f"LLM returned param_name '{item.param_name}' but name at index {param_idx} is '{actual_param_name}'. Using actual name from results.")

            effect_key = 'treatment_effect'  # Default for single/binary treatment
            if is_multilevel_case:
                level_match = re.search(r'\[T\.([^]]+)]', actual_param_name)
                if level_match:
                    level = level_match.group(1)
                    if level != reference_level_str:  # Never key by the reference level itself
                        effect_key = level
                else:
                    logger.warning(f"Could not parse level from LLM-identified param: {actual_param_name}. Storing with raw name.")
                    effect_key = actual_param_name
            effects[effect_key] = _collect_param_stats(results, conf_int_df, param_idx)

        if not effects:
            logger.warning("LLM identified parameters, but none could be processed into effects_estimates_by_level. Falling back to parameter-name matching.")
            return {}, False

        all_identified = response.all_parameters_successfully_identified
        logger.info(f"Successfully processed LLM-identified parameters. all_parameters_successfully_identified={all_identified}")
        logger.debug(f"effect_estimates_by_level: {effects}")
        return effects, all_identified
    except Exception as e_llm:
        logger.warning(f"LLM-based result extraction failed: {e_llm}. Falling back to parameter-name matching.", exc_info=True)
        return {}, False


def _extract_effects_from_param_names(
    results: Any,
    treatment_patsy_term: str,
    is_multilevel_case: bool,
) -> tuple:
    """Deterministically pick the treatment's main-effect parameters by name.

    A parameter qualifies when its name is exactly the treatment term
    (numeric treatment) or the term followed by a single ``[...]`` level
    suffix (categorical treatment). Interaction parameters are excluded.

    Returns ``(effects_by_key, found_any)``.
    """
    conf_int_df = results.conf_int(alpha=0.05)
    matches = []
    for position, name in enumerate(results.params.index):
        if not name.startswith(treatment_patsy_term):
            continue
        suffix = name[len(treatment_patsy_term):]
        is_level_suffix = suffix.startswith('[') and suffix.endswith(']') and ']:' not in suffix
        if suffix == '' or is_level_suffix:
            matches.append((position, name))

    # Several coefficients cannot share the single 'treatment_effect' key.
    key_by_level = is_multilevel_case or len(matches) > 1
    effects = {}
    for position, name in matches:
        effect_key = 'treatment_effect'
        if key_by_level:
            level_match = re.search(r'\[T\.([^]]+)]', name)
            effect_key = level_match.group(1) if level_match else name
        effects[effect_key] = _collect_param_stats(results, conf_int_df, position)

    if effects:
        logger.info(f"Extracted treatment parameters by name matching: {[name for _, name in matches]}")
    else:
        logger.warning(f"No fitted parameter matched treatment term '{treatment_patsy_term}'.")
    return effects, bool(effects)


def _summarise_main_effect(effects_by_key: Dict[str, Any]) -> tuple:
    """Collapse per-key stats into ``(estimate, p_value, conf_int, std_err)``.

    A single effect yields scalars (and a ``[low, high]`` list); several
    levels yield dicts keyed by level; no effects yields ``None`` placeholders.
    """
    if not effects_by_key:
        return None, None, [None, None], None
    if 'treatment_effect' in effects_by_key:
        single = effects_by_key['treatment_effect']
        return single['estimate'], single['p_value'], single['conf_int'], single['std_err']
    logger.info("Multi-level treatment effects extracted. Populating dicts for main estimate fields.")
    return (
        {level: stats.get('estimate') for level, stats in effects_by_key.items()},
        {level: stats.get('p_value') for level, stats in effects_by_key.items()},
        {level: stats.get('conf_int') for level, stats in effects_by_key.items()},
        {level: stats.get('std_err') for level, stats in effects_by_key.items()},
    )


def estimate_effect(
    df: pd.DataFrame,
    treatment: str,
    outcome: str,
    covariates: Optional[List[str]] = None,
    query_str: Optional[str] = None,  # For potential LLM use
    llm: Optional["BaseChatModel"] = None,  # For potential LLM use
    **kwargs  # To capture any other potential arguments
) -> Dict[str, Any]:
    """
    Estimates the causal effect using Linear Regression (OLS).

    The treatment coefficient(s) are located with LLM assistance when both a
    query and an LLM are available; otherwise, or when that fails, they are
    located deterministically from the fitted parameter names.

    Args:
        df: Input DataFrame.
        treatment: Name of the treatment variable column.
        outcome: Name of the outcome variable column.
        covariates: Optional list of covariate names.
        query_str: Optional user query for context (e.g., for LLM).
        llm: Optional Language Model instance. When omitted and ``query_str``
            is given, a client is obtained from ``get_llm_client()``.
        **kwargs: Additional keyword arguments. Recognised keys:
            - 'interaction_term_suggested': whether to add a treatment interaction.
            - 'interaction_variable_candidate': original name of the moderator.
            - 'treatment_reference_level': reference level for a categorical treatment.
            - 'column_mappings': per-column preprocessing info (dict of dicts).

    Returns:
        Dictionary containing estimation results:
        - 'effect_estimate': Treatment coefficient (dict keyed by level for multi-level treatments).
        - 'p_value': The p-value associated with the treatment coefficient(s).
        - 'confidence_interval': The 95% confidence interval for the effect(s).
        - 'standard_error': The standard error of the treatment coefficient(s).
        - 'estimated_effects_by_level': Per-level stats when a reference level applies, else None.
        - 'reference_level_used': The reference level when it applies, else None.
        - 'formula': The regression formula used.
        - 'model_summary_text': Text of the statsmodels summary.
        - 'diagnostics': Placeholder for diagnostic results.
        - 'interpretation_details': Interaction coefficient/p-value when available.
        - 'interpretation': Placeholder for LLM interpretation.
        - 'method_used': Name of the estimation method.
        - 'warnings': Present only when parameters could not be reliably extracted.

    Raises:
        ValueError: If required columns are missing or no rows remain after dropping NaNs.
        Exception: Any error raised while fitting or extracting results is logged and re-raised.
    """
    covariates = list(covariates) if covariates else []

    # Retrieve additional args from kwargs
    interaction_term_suggested = kwargs.get('interaction_term_suggested', False)
    # interaction_variable_candidate is the *original* name from query_interpreter
    interaction_variable_candidate_orig_name = kwargs.get('interaction_variable_candidate')
    treatment_reference_level = kwargs.get('treatment_reference_level')
    column_mappings = kwargs.get('column_mappings') or {}  # tolerate an explicit None

    # De-duplicate while preserving order: a repeated name would make
    # df[...] return duplicate columns and break per-column dtype checks.
    required_cols = list(dict.fromkeys([treatment, outcome] + covariates))
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    df_analysis = df[required_cols].dropna()
    if df_analysis.empty:
        raise ValueError("No data remaining after dropping NaNs for required columns.")

    # --- Formula Construction ---
    treatment_info = column_mappings.get(treatment) or {}  # Info from preprocess_data
    is_binary_encoded = treatment_info.get('transformed_as') == 'label_encoded_binary'
    is_still_categorical_in_df = df_analysis[treatment].dtype.name in ['object', 'category']

    # 1. Treatment term
    treatment_patsy_term = _build_treatment_term(
        treatment, is_still_categorical_in_df, is_binary_encoded, treatment_reference_level
    )
    rhs_terms = [treatment_patsy_term]

    # 2. Covariate terms
    for cov_col_name in covariates:
        if cov_col_name == treatment:  # Should not happen if covariates list is clean
            continue
        # Covariates are expected to be numeric/dummy already; wrap stragglers in C().
        if df_analysis[cov_col_name].dtype.name in ['object', 'category']:
            rhs_terms.append(f"C({cov_col_name})")
        else:
            rhs_terms.append(cov_col_name)

    # 3. Interaction term (the moderator must map to a single column in df_analysis)
    actual_interaction_term_added_to_formula = None
    if interaction_term_suggested and interaction_variable_candidate_orig_name:
        actual_interaction_term_added_to_formula = _build_interaction_term(
            df_analysis, treatment_patsy_term, interaction_variable_candidate_orig_name, column_mappings
        )
        if actual_interaction_term_added_to_formula:
            rhs_terms.append(actual_interaction_term_added_to_formula)

    formula = f"{outcome} ~ {' + '.join(rhs_terms)}"
    logger.info(f"Using formula for Linear Regression: {formula}")

    try:
        results = smf.ols(formula=formula, data=df_analysis).fit()
        logger.info("OLS model fitted successfully.")
        logger.debug(results.summary())  # Full table only at debug level

        is_multilevel_case = bool(
            treatment_reference_level and is_still_categorical_in_df and not is_binary_encoded
        )
        reference_level_str = str(treatment_reference_level) if is_multilevel_case else "N/A"

        # --- Result Extraction: LLM attempt first, then name-matching fallback ---
        effect_estimates_by_level = {}
        all_params_extracted = False
        if query_str:
            # Respect a caller-supplied LLM; only build a client when none was given.
            if llm is None:
                llm = get_llm_client()
            if llm:
                effect_estimates_by_level, all_params_extracted = _extract_effects_via_llm(
                    llm, query_str, results, treatment_patsy_term, treatment,
                    is_multilevel_case, reference_level_str,
                )
        if not effect_estimates_by_level:
            effect_estimates_by_level, all_params_extracted = _extract_effects_from_param_names(
                results, treatment_patsy_term, is_multilevel_case
            )

        # Headline numbers; full per-level details stay in effect_estimates_by_level.
        main_effect_estimate, main_p_value, main_conf_int, main_std_err = _summarise_main_effect(
            effect_estimates_by_level
        )

        interpretation_details = {}
        if actual_interaction_term_added_to_formula and actual_interaction_term_added_to_formula in results.params.index:
            interpretation_details['interaction_term_coefficient'] = results.params[actual_interaction_term_added_to_formula]
            interpretation_details['interaction_term_p_value'] = results.pvalues[actual_interaction_term_added_to_formula]
            logger.info(f"Interaction term '{actual_interaction_term_added_to_formula}' coeff: {interpretation_details['interaction_term_coefficient']}")

        output_dict = {
            'effect_estimate': main_effect_estimate,
            'p_value': main_p_value,
            'confidence_interval': main_conf_int,
            'standard_error': main_std_err,
            'estimated_effects_by_level': effect_estimates_by_level if (is_multilevel_case and effect_estimates_by_level) else None,
            'reference_level_used': treatment_reference_level if is_multilevel_case else None,
            'formula': formula,
            'model_summary_text': results.summary().as_text(),  # Text form serializes easily
            'diagnostics': {},
            'interpretation_details': interpretation_details,
            'interpretation': "Interpretation not available.",
            'method_used': 'Linear Regression (OLS)'
        }
        if not all_params_extracted:
            output_dict['warnings'] = ["Could not reliably extract all requested parameters from model results. Please check model_summary_text."]
        return output_dict
    except Exception as e:
        logger.error(f"Linear Regression failed: {e}")
        raise  # Re-raise the exception after logging