"""
Method Selector Tool for selecting causal inference methods.

This module provides a LangChain tool for selecting appropriate
causal inference methods based on dataset characteristics and query details.
"""
import logging
from typing import Any, Dict, List, Optional, Union

from langchain_core.tools import tool  # langchain_core is the current supported import path

# Project components: rule-based and LLM-based decision engines, the central
# LLM factory, the workflow-state helper, and shared Pydantic models.
from auto_causal.components.decision_tree import rule_based_select_method  # Rule-based engine
from auto_causal.components.decision_tree_llm import DecisionTreeLLMEngine  # LLM-based engine
from auto_causal.config import get_llm_client
from auto_causal.components.state_manager import create_workflow_state_update
from auto_causal.models import (
    Variables,
    DatasetAnalysis,
    MethodSelectorInput,  # still needed for args_schema on the tool wrapper
)

logger = logging.getLogger(__name__)
def _build_error_result(
    error_msg: str,
    reason: str,
    next_tool: str,
    variables_dict: Dict[str, Any],
    dataset_analysis_dict: Dict[str, Any],
    dataset_description: Optional[str],
) -> Dict[str, Any]:
    """Build the standard failure payload for method selection.

    Combines the top-level ``"error"`` key, the pass-through context
    (variables / analysis / description), and the flattened workflow-state
    keys produced by ``create_workflow_state_update``.

    Args:
        error_msg: Value for the top-level ``"error"`` key in the result.
        reason: Message used both as the workflow's next-step reason and its
            recorded error (the two were always identical in practice).
        next_tool: Tool the agent should invoke next.
        variables_dict: Dumped ``Variables`` model to pass through.
        dataset_analysis_dict: Dumped ``DatasetAnalysis`` model to pass through.
        dataset_description: Optional dataset description to pass through.

    Returns:
        Error result dictionary for the agent.
    """
    workflow_update = create_workflow_state_update(
        current_step="method_selection",
        step_completed_flag=False,
        next_tool=next_tool,
        next_step_reason=reason,
        error=reason,
    )
    return {
        "error": error_msg,
        "variables": variables_dict,
        "dataset_analysis": dataset_analysis_dict,
        "dataset_description": dataset_description,
        # Flatten the workflow-state sub-dict into the top level, matching the
        # shape of the success result.
        **workflow_update.get("workflow_state", {}),
    }


def method_selector_tool(
    variables: Variables,
    dataset_analysis: DatasetAnalysis,
    dataset_description: Optional[str] = None,
    original_query: Optional[str] = None,
    excluded_methods: Optional[List[str]] = None,
    use_llm_decision_tree: bool = False,
) -> Dict[str, Any]:
    """
    Select the most appropriate causal inference method based on structured input.

    Applies decision logic (rule-based by default, optionally LLM-based) using
    the dataset analysis and identified variables (including the ``is_rct`` flag).

    Args:
        variables: Identified variables (treatment, outcome, covariates, IV,
            RDD, is_rct, etc.).
        dataset_analysis: Results of the dataset analysis step.
        dataset_description: Optional textual description of the dataset.
        original_query: Optional original user query string.
        excluded_methods: Optional list of method names to exclude from selection.
        use_llm_decision_tree: When True, use the LLM-based decision tree engine
            instead of the rule-based one. Defaults to False, preserving the
            previous behavior where this switch was a hard-coded constant.

    Returns:
        Dictionary with method selection details (``"method_info"``), pass-through
        context for the next step, and flattened workflow-state keys. On failure,
        an ``"error"`` key is present instead of ``"method_info"``.
    """
    logger.info("Running method_selector_tool with individual args...")

    # Arguments arrive as Pydantic models; downstream components expect plain dicts.
    is_rct_flag = variables.is_rct
    variables_dict = variables.model_dump()
    dataset_analysis_dict = dataset_analysis.model_dump()

    # Basic validation: both treatment and outcome must have been identified.
    treatment = variables_dict.get("treatment_variable")
    outcome = variables_dict.get("outcome_variable")
    if not all([treatment, outcome]):
        logger.error("Missing treatment or outcome variable in input.")
        # Route back to this tool so the agent can retry once inputs are fixed.
        return _build_error_result(
            error_msg="Missing treatment/outcome",
            reason="Missing treatment/outcome variable in input",
            next_tool="method_selector_tool",
            variables_dict=variables_dict,
            dataset_analysis_dict=dataset_analysis_dict,
            dataset_description=dataset_description,
        )

    # The LLM client is optional: selection can proceed without LLM features.
    try:
        llm_instance = get_llm_client()
    except Exception as e:
        logger.warning(
            "Failed to initialize LLM for method_selector_tool: %s. Proceeding without LLM features.",
            e,
        )
        llm_instance = None

    # Normalize is_rct once: anything other than a real bool is treated as False.
    is_rct = is_rct_flag if isinstance(is_rct_flag, bool) else False

    # Run the selected decision engine.
    try:
        if use_llm_decision_tree:
            logger.info("Using LLM-based Decision Tree Engine for method selection.")
            if not llm_instance:
                # DecisionTreeLLMEngine handles a missing llm itself, so we only warn.
                logger.warning(
                    "LLM instance is required for DecisionTreeLLMEngine but not available. "
                    "Falling back to rule-based or error."
                )
            llm_engine = DecisionTreeLLMEngine(verbose=True)
            method_selection_dict = llm_engine.select_method_llm(
                dataset_analysis=dataset_analysis_dict,
                variables=variables_dict,
                is_rct=is_rct,
                llm=llm_instance,
                excluded_methods=excluded_methods,
            )
        else:
            logger.info("Using Rule-based Decision Tree Engine for method selection.")
            method_selection_dict = rule_based_select_method(
                dataset_analysis=dataset_analysis_dict,
                variables=variables_dict,
                is_rct=is_rct,
                llm=llm_instance,
                dataset_description=dataset_description,
                original_query=original_query,
                excluded_methods=excluded_methods,
            )
    except Exception as e:
        logger.error("Error during method selection execution: %s", e, exc_info=True)
        return _build_error_result(
            error_msg=f"Method selection logic failed: {e}",
            reason=f"Component failed: {e}",
            next_tool="error_handler_tool",
            variables_dict=variables_dict,
            dataset_analysis_dict=dataset_analysis_dict,
            dataset_description=dataset_description,
        )

    # --- Prepare output dictionary ---
    selected = method_selection_dict.get("selected_method")
    method_selected_flag = bool(selected and selected != "Error")

    # 'method_info' sub-dictionary is required by the downstream validator;
    # include alternatives when the selection output provides them.
    method_info = {
        "selected_method": selected,
        "method_name": selected.replace("_", " ").title() if method_selected_flag else None,
        "method_justification": method_selection_dict.get("method_justification"),
        "method_assumptions": method_selection_dict.get("method_assumptions", []),
        "alternative_methods": method_selection_dict.get("alternatives", []),
    }

    result = {
        "method_info": method_info,
        "variables": variables_dict,
        "dataset_analysis": dataset_analysis_dict,
        "dataset_description": dataset_description,
        "original_query": original_query,
    }

    # Workflow routing: validate the chosen method next, or hand off to the
    # error handler when selection failed.
    if method_selected_flag:
        next_tool_name = "method_validator_tool"
        next_reason = "Now we need to validate the assumptions of the selected method"
    else:
        next_tool_name = "error_handler_tool"
        next_reason = "Method selection failed or returned an error."

    workflow_update = create_workflow_state_update(
        current_step="method_selection",
        step_completed_flag=method_selected_flag,
        next_tool=next_tool_name,
        next_step_reason=next_reason,
    )
    result.update(workflow_update.get("workflow_state", {}))

    logger.info("method_selector_tool finished. Selected: %s", method_info.get("selected_method"))
    return result