Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| DataFrame Builder for Causal Analysis | |
| This module creates DataFrames for causal analysis from provided data. | |
| It no longer accesses the database directly and operates as pure functions. | |
| """ | |
| import pandas as pd | |
| import json | |
| import os | |
| from typing import Union, Dict, List, Optional, Any | |
| import logging | |
| logger = logging.getLogger(__name__) | |
def create_component_influence_dataframe(
    perturbation_tests: List[Dict],
    prompt_reconstructions: List[Dict],
    relations: List[Dict]
) -> Optional[pd.DataFrame]:
    """
    Create a DataFrame for component influence analysis from provided data.

    This is a pure function that takes data as parameters instead of
    querying the database directly.

    Args:
        perturbation_tests: Perturbation test dicts; each must have a
            'relation_id' key and may carry a 'perturbation_score'.
        prompt_reconstructions: Prompt reconstruction dicts; each must have
            a 'relation_id' key and may carry a 'dependencies' mapping with
            'entities' and 'relations' ID lists.
        relations: Relation dicts from the knowledge graph; each should have
            an 'id' and may carry 'type', 'source' and 'target'.

    Returns:
        pandas.DataFrame with one row per relation that has both a
        reconstruction and a perturbation test. Columns: relation metadata,
        'perturbation' score, and binary 'entity_*' / 'relation_*'
        dependency features (sorted for deterministic column order).
        None if no valid rows could be built or creation fails.
    """
    try:
        # Index reconstructions and tests by relation_id for O(1) lookup.
        pr_by_relation = {pr['relation_id']: pr for pr in prompt_reconstructions}
        pt_by_relation = {pt['relation_id']: pt for pt in perturbation_tests}

        # First pass: collect every entity/relation ID used as a dependency.
        # Sorted so the feature-column order is stable across runs (raw set
        # iteration order is not).
        entity_ids, relation_ids = _collect_dependency_ids(relations, pr_by_relation)

        # Second pass: build one feature row per usable relation.
        rows = []
        for i, relation in enumerate(relations):
            try:
                print(f"\nProcessing relation {i+1}/{len(relations)}:")
                print(f"- Relation ID: {relation.get('id', 'unknown')}")
                print(f"- Relation type: {relation.get('type', 'unknown')}")
                relation_id = relation.get('id')
                if not relation_id:
                    print(f"Skipping relation without ID")
                    continue
                pr = pr_by_relation.get(relation_id)
                pt = pt_by_relation.get(relation_id)
                if not pr or not pt:
                    print(f"Skipping relation {relation_id}, missing reconstruction or test")
                    continue
                print(f"- Found prompt reconstruction and perturbation test")
                print(f"- Perturbation score: {pt.get('perturbation_score', 0)}")
                rows.append(_build_feature_row(relation, pr, pt, entity_ids, relation_ids))
            except Exception as e:
                # Best effort: one malformed relation must not abort the build.
                print(f"Error processing relation {relation.get('id', 'unknown')}: {str(e)}")
                continue

        if not rows:
            print("No valid rows created")
            return None

        df = pd.DataFrame(rows)
        print(f"\nCreated DataFrame with {len(df)} rows and {len(df.columns)} columns")
        print(f"Columns: {list(df.columns)}")

        # Basic validation: downstream analysis requires the target column.
        if 'perturbation' not in df.columns:
            print("ERROR: 'perturbation' column missing from DataFrame")
            return None

        feature_cols = [col for col in df.columns if col.startswith(('entity_', 'relation_'))]
        if not feature_cols:
            print("WARNING: No feature columns found in DataFrame")
        else:
            print(f"Found {len(feature_cols)} feature columns")
        return df
    except Exception as e:
        logger.error(f"Error creating component influence DataFrame: {str(e)}")
        return None


def _collect_dependency_ids(relations: List[Dict], pr_by_relation: Dict) -> tuple:
    """Collect unique dependency IDs across all usable relations.

    Returns a (entity_ids, relation_ids) pair of sorted lists so that the
    generated feature columns have a deterministic order.
    """
    all_entity_ids = set()
    all_relation_ids = set()
    for relation in relations:
        relation_id = relation.get('id')
        if not relation_id or relation_id not in pr_by_relation:
            continue
        dependencies = pr_by_relation[relation_id].get('dependencies', {})
        if isinstance(dependencies, dict):
            entities = dependencies.get('entities', [])
            relations_deps = dependencies.get('relations', [])
            if isinstance(entities, list):
                all_entity_ids.update(entities)
            if isinstance(relations_deps, list):
                all_relation_ids.update(relations_deps)
    # key=str guards against mixed-type IDs (e.g. ints and strings).
    return sorted(all_entity_ids, key=str), sorted(all_relation_ids, key=str)


def _build_feature_row(relation: Dict, pr: Dict, pt: Dict,
                       entity_ids: List, relation_ids: List) -> Dict:
    """Build one DataFrame row: relation metadata, score, binary features."""
    row = {
        'relation_id': relation.get('id'),
        'relation_type': relation.get('type'),
        'source': relation.get('source'),
        'target': relation.get('target'),
        'perturbation': pt.get('perturbation_score', 0)
    }
    dependencies = pr.get('dependencies', {})
    raw_entities = dependencies.get('entities', []) if isinstance(dependencies, dict) else []
    raw_relations = dependencies.get('relations', []) if isinstance(dependencies, dict) else []
    # Sets give O(1) membership tests in the feature loops below
    # (the original list scans made row building O(features * deps)).
    entity_deps = set(raw_entities) if isinstance(raw_entities, list) else set()
    relation_deps = set(raw_relations) if isinstance(raw_relations, list) else set()
    for entity_id in entity_ids:
        row[f"entity_{entity_id}"] = 1 if entity_id in entity_deps else 0
    for rel_id in relation_ids:
        row[f"relation_{rel_id}"] = 1 if rel_id in relation_deps else 0
    return row
def create_component_influence_dataframe_from_file(input_path: str) -> Optional[pd.DataFrame]:
    """
    Build the component-influence DataFrame from a JSON analysis file.

    Legacy wrapper kept for backward compatibility: it loads the JSON
    payload from disk, pulls out the perturbation tests, prompt
    reconstructions and knowledge-graph relations, then delegates to the
    pure function create_component_influence_dataframe().

    Args:
        input_path: Path to the JSON file containing analysis data.

    Returns:
        The resulting pandas.DataFrame, or None on any failure
        (unreadable file, malformed JSON, or DataFrame creation error).
    """
    try:
        with open(input_path, 'r') as handle:
            payload = json.load(handle)

        tests = payload.get('perturbation_tests', [])
        reconstructions = payload.get('prompt_reconstructions', [])
        graph_relations = payload.get('knowledge_graph', {}).get('relations', [])

        return create_component_influence_dataframe(tests, reconstructions, graph_relations)
    except Exception as e:
        logger.error(f"Error creating DataFrame from file {input_path}: {str(e)}")
        return None
def main():
    """
    CLI entry point for testing the DataFrame builder.

    Reads analysis data from the JSON file given by --input, prints summary
    statistics of the resulting DataFrame, and optionally writes it out as
    CSV via --output.

    Returns:
        0 on success, 1 if the DataFrame could not be created.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Test component influence DataFrame creation')
    parser.add_argument('--input', type=str, required=True,
                        help='Path to input JSON file with analysis data')
    parser.add_argument('--output', type=str,
                        help='Path to output CSV file (optional)')
    args = parser.parse_args()

    # Create DataFrame from file
    df = create_component_influence_dataframe_from_file(args.input)
    if df is None:
        print("ERROR: Failed to create DataFrame")
        return 1

    print(f"Successfully created DataFrame with {len(df)} rows and {len(df.columns)} columns")
    print(f"Columns: {list(df.columns)}")
    print(f"Perturbation score stats:")
    print(f" Mean: {df['perturbation'].mean():.4f}")
    print(f" Std: {df['perturbation'].std():.4f}")
    print(f" Min: {df['perturbation'].min():.4f}")
    print(f" Max: {df['perturbation'].max():.4f}")

    # Save to CSV if requested
    if args.output:
        df.to_csv(args.output, index=False)
        print(f"DataFrame saved to {args.output}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status code to the shell; the return value was
    # previously discarded, so failures exited with status 0.
    raise SystemExit(main())