#!/usr/bin/env python3
"""
DataFrame Builder for Causal Analysis

This module creates DataFrames for causal analysis from provided data.
It no longer accesses the database directly and operates as pure functions.
"""

import json
import logging
import os
from typing import Union, Dict, List, Optional, Any

import pandas as pd

logger = logging.getLogger(__name__)

# Columns every row carries regardless of which dependencies exist. Two of
# them share the 'relation_' prefix with feature columns, so feature detection
# has to exclude them explicitly.
_BASE_COLUMNS = ('relation_id', 'relation_type', 'source', 'target', 'perturbation')


def _index_by_relation_id(items: Optional[List[Dict]]) -> Dict[Any, Dict]:
    """
    Build a ``relation_id -> item`` lookup.

    Entries that are not dicts or that carry no ``relation_id`` are skipped
    (with a warning) instead of aborting the whole build. When several entries
    share a ``relation_id`` the last one wins.
    """
    indexed: Dict[Any, Dict] = {}
    for item in items or []:
        relation_id = item.get('relation_id') if isinstance(item, dict) else None
        if relation_id is None:
            logger.warning("Ignoring entry without a relation_id: %r", item)
            continue
        indexed[relation_id] = item
    return indexed


def _dependency_ids(reconstruction: Dict, key: str) -> List[Any]:
    """
    Return the list stored at ``reconstruction['dependencies'][key]``.

    Anything that is not a list (missing key, ``None``, a string, ...) is
    treated as "no dependencies" so both passes of the builder agree.
    """
    dependencies = reconstruction.get('dependencies', {})
    if not isinstance(dependencies, dict):
        return []
    ids = dependencies.get(key, [])
    return ids if isinstance(ids, list) else []


def create_component_influence_dataframe(
    perturbation_tests: List[Dict],
    prompt_reconstructions: List[Dict],
    relations: List[Dict]
) -> Optional[pd.DataFrame]:
    """
    Create a DataFrame for component influence analysis from provided data.

    This is a pure function that takes data as parameters instead of
    querying the database directly. One row is produced for every relation
    that has both a prompt reconstruction and a perturbation test. Each row
    holds the relation's ``id``/``type``/``source``/``target``, its
    ``perturbation_score`` (0 when absent) and one 0/1 column per entity or
    relation ID seen in any matched reconstruction's dependencies
    (``entity_<id>`` / ``relation_<id>``). Feature columns are emitted in a
    deterministic order (sorted by the string form of the ID).

    Args:
        perturbation_tests: List of perturbation test dictionaries
        prompt_reconstructions: List of prompt reconstruction dictionaries
        relations: List of relation dictionaries from the knowledge graph

    Returns:
        pandas.DataFrame with component features and perturbation scores,
        or None if no row could be built or creation fails
    """
    try:
        pr_by_relation = _index_by_relation_id(prompt_reconstructions)
        pt_by_relation = _index_by_relation_id(perturbation_tests)

        # First pass: collect every dependency ID referenced by a relation
        # that has a reconstruction, so all rows share the same columns.
        all_entity_ids = set()
        all_relation_ids = set()
        for relation in relations:
            relation_id = relation.get('id')
            # Compare against None: 0 and other falsy values are valid IDs.
            if relation_id is None:
                continue
            pr = pr_by_relation.get(relation_id)
            if pr is None:
                continue
            all_entity_ids.update(_dependency_ids(pr, 'entities'))
            all_relation_ids.update(_dependency_ids(pr, 'relations'))

        # Set iteration order is not stable across runs for string IDs, so
        # sort to get reproducible columns. key=str keeps mixed int/str IDs
        # comparable.
        entity_columns = sorted(all_entity_ids, key=str)
        relation_columns = sorted(all_relation_ids, key=str)

        # Second pass: create feature rows
        rows = []
        total = len(relations)
        for position, relation in enumerate(relations, start=1):
            relation_id = relation.get('id')
            logger.debug(
                "Processing relation %d/%d: id=%s type=%s",
                position, total, relation_id, relation.get('type', 'unknown')
            )

            if relation_id is None:
                logger.debug("Skipping relation without ID")
                continue

            pr = pr_by_relation.get(relation_id)
            pt = pt_by_relation.get(relation_id)
            if pr is None or pt is None:
                logger.debug(
                    "Skipping relation %s, missing reconstruction or test",
                    relation_id
                )
                continue

            row = {
                'relation_id': relation_id,
                'relation_type': relation.get('type'),
                'source': relation.get('source'),
                'target': relation.get('target'),
                'perturbation': pt.get('perturbation_score', 0)
            }

            # NOTE(review): a relation dependency whose ID is literally "id"
            # or "type" would produce a feature name that overwrites a base
            # column; the naming scheme is kept as-is for compatibility.
            # Sets give O(1) membership tests per feature column.
            entity_deps = set(_dependency_ids(pr, 'entities'))
            for entity_id in entity_columns:
                row[f"entity_{entity_id}"] = 1 if entity_id in entity_deps else 0

            relation_deps = set(_dependency_ids(pr, 'relations'))
            for rel_id in relation_columns:
                row[f"relation_{rel_id}"] = 1 if rel_id in relation_deps else 0

            rows.append(row)

        if not rows:
            logger.warning("No valid rows created")
            return None

        df = pd.DataFrame(rows)
        logger.info(
            "Created DataFrame with %d rows and %d columns",
            len(df), len(df.columns)
        )
        logger.debug("Columns: %s", list(df.columns))

        # Base columns such as 'relation_id' also start with 'relation_';
        # exclude them so the warning below can actually trigger.
        feature_cols = [
            col for col in df.columns
            if col not in _BASE_COLUMNS and col.startswith(('entity_', 'relation_'))
        ]
        if not feature_cols:
            logger.warning("No feature columns found in DataFrame")
        else:
            logger.info("Found %d feature columns", len(feature_cols))

        return df

    except Exception:
        # Documented contract: never raise, signal failure with None.
        logger.exception("Error creating component influence DataFrame")
        return None


def create_component_influence_dataframe_from_file(input_path: str) -> Optional[pd.DataFrame]:
    """
    Create a DataFrame for component influence analysis from a JSON file.
    Legacy function maintained for backward compatibility.

    The file is expected to be a JSON object with the keys
    ``perturbation_tests``, ``prompt_reconstructions`` and
    ``knowledge_graph.relations``; missing or null values are treated as
    empty.

    Args:
        input_path: Path to the JSON file containing analysis data

    Returns:
        pandas.DataFrame with component features and perturbation scores,
        or None if creation fails
    """
    try:
        # Explicit encoding so the result does not depend on the platform's
        # default locale encoding.
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # `or` guards against keys that are present but null.
        perturbation_tests = data.get('perturbation_tests') or []
        prompt_reconstructions = data.get('prompt_reconstructions') or []
        knowledge_graph = data.get('knowledge_graph') or {}
        relations = knowledge_graph.get('relations') or []

        # Call the pure function
        return create_component_influence_dataframe(
            perturbation_tests,
            prompt_reconstructions,
            relations
        )

    except Exception:
        # Documented contract: never raise, signal failure with None.
        logger.exception("Error creating DataFrame from file %s", input_path)
        return None


def main() -> int:
    """
    Main function for testing the DataFrame builder.

    Parses ``--input`` (required JSON path) and ``--output`` (optional CSV
    path), builds the DataFrame, prints summary statistics of the
    ``perturbation`` column and optionally writes the CSV.

    Returns:
        0 on success, 1 if the DataFrame could not be created.
    """
    import argparse

    # Surface the builder's log messages when run as a script.
    logging.basicConfig(level=logging.INFO, format='%(message)s')

    parser = argparse.ArgumentParser(description='Test component influence DataFrame creation')
    parser.add_argument('--input', type=str, required=True,
                        help='Path to input JSON file with analysis data')
    parser.add_argument('--output', type=str,
                        help='Path to output CSV file (optional)')
    args = parser.parse_args()

    # Create DataFrame from file
    df = create_component_influence_dataframe_from_file(args.input)

    if df is None:
        print("ERROR: Failed to create DataFrame")
        return 1

    print(f"Successfully created DataFrame with {len(df)} rows and {len(df.columns)} columns")
    print(f"Columns: {list(df.columns)}")
    print("Perturbation score stats:")
    print(f"  Mean: {df['perturbation'].mean():.4f}")
    print(f"  Std: {df['perturbation'].std():.4f}")
    print(f"  Min: {df['perturbation'].min():.4f}")
    print(f"  Max: {df['perturbation'].max():.4f}")

    # Save to CSV if requested
    if args.output:
        df.to_csv(args.output, index=False)
        print(f"DataFrame saved to {args.output}")

    return 0


if __name__ == "__main__":
    # Propagate main()'s status code to the shell.
    raise SystemExit(main())