File size: 8,262 Bytes
c2ea5ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#!/usr/bin/env python3
"""
DataFrame Builder for Causal Analysis

This module creates DataFrames for causal analysis from provided data.
It no longer accesses the database directly and operates as pure functions.
"""

import pandas as pd
import json
import os
from typing import Union, Dict, List, Optional, Any
import logging

logger = logging.getLogger(__name__)

def create_component_influence_dataframe(
    perturbation_tests: List[Dict],
    prompt_reconstructions: List[Dict],
    relations: List[Dict]
) -> Optional[pd.DataFrame]:
    """
    Create a DataFrame for component influence analysis from provided data.

    This is a pure function that takes data as parameters instead of
    querying the database directly. One row is produced per relation that
    has both a prompt reconstruction and a perturbation test; binary
    `entity_<id>` / `relation_<id>` feature columns mark which components
    each reconstruction depends on.

    Args:
        perturbation_tests: List of perturbation test dictionaries; each
            must carry 'relation_id' and may carry 'perturbation_score'.
        prompt_reconstructions: List of prompt reconstruction dictionaries;
            each must carry 'relation_id' and may carry a 'dependencies'
            dict with 'entities' and 'relations' lists.
        relations: List of relation dictionaries from the knowledge graph
            (keys used: 'id', 'type', 'source', 'target').

    Returns:
        pandas.DataFrame with component features and perturbation scores,
        or None if creation fails or no valid rows can be built.
    """
    try:
        # Index reconstructions and tests by relation_id for O(1) lookup.
        pr_by_relation = {pr['relation_id']: pr for pr in prompt_reconstructions}
        pt_by_relation = {pt['relation_id']: pt for pt in perturbation_tests}

        # First pass: collect the universe of entity/relation dependency IDs
        # so every row gets the same feature columns.
        all_entity_ids = set()
        all_relation_ids = set()
        for relation in relations:
            relation_id = relation.get('id')
            if not relation_id or relation_id not in pr_by_relation:
                continue

            dependencies = pr_by_relation[relation_id].get('dependencies', {})
            if isinstance(dependencies, dict):
                entities = dependencies.get('entities', [])
                relations_deps = dependencies.get('relations', [])
                if isinstance(entities, list):
                    all_entity_ids.update(entities)
                if isinstance(relations_deps, list):
                    all_relation_ids.update(relations_deps)

        # Sort the ID sets: set iteration order is arbitrary, which previously
        # made the DataFrame's column order nondeterministic across runs.
        # key=str guards against mixed-type IDs.
        entity_ids = sorted(all_entity_ids, key=str)
        relation_ids = sorted(all_relation_ids, key=str)

        # Second pass: build one feature row per usable relation.
        rows = []
        for i, relation in enumerate(relations):
            try:
                relation_id = relation.get('id')
                logger.debug(
                    "Processing relation %d/%d: id=%s type=%s",
                    i + 1, len(relations),
                    relation_id or 'unknown', relation.get('type', 'unknown')
                )

                if not relation_id:
                    logger.debug("Skipping relation without ID")
                    continue

                pr = pr_by_relation.get(relation_id)
                pt = pt_by_relation.get(relation_id)
                if not pr or not pt:
                    logger.debug(
                        "Skipping relation %s, missing reconstruction or test",
                        relation_id
                    )
                    continue

                logger.debug(
                    "Perturbation score for %s: %s",
                    relation_id, pt.get('perturbation_score', 0)
                )

                # Base attributes plus the target variable for the analysis.
                row = {
                    'relation_id': relation_id,
                    'relation_type': relation.get('type'),
                    'source': relation.get('source'),
                    'target': relation.get('target'),
                    'perturbation': pt.get('perturbation_score', 0)
                }

                dependencies = pr.get('dependencies', {})
                entity_deps = dependencies.get('entities', []) if isinstance(dependencies, dict) else []
                relation_deps = dependencies.get('relations', []) if isinstance(dependencies, dict) else []

                # Binary indicator features over the shared ID universe.
                for entity_id in entity_ids:
                    row[f"entity_{entity_id}"] = 1 if entity_id in entity_deps else 0
                for rel_id in relation_ids:
                    row[f"relation_{rel_id}"] = 1 if rel_id in relation_deps else 0

                rows.append(row)

            except Exception as e:
                # One malformed relation must not abort the whole build.
                logger.warning(
                    "Error processing relation %s: %s",
                    relation.get('id', 'unknown'), e
                )
                continue

        if not rows:
            logger.warning("No valid rows created")
            return None

        df = pd.DataFrame(rows)
        logger.debug(
            "Created DataFrame with %d rows and %d columns: %s",
            len(df), len(df.columns), list(df.columns)
        )

        # Basic validation: the target column must exist.
        if 'perturbation' not in df.columns:
            logger.error("'perturbation' column missing from DataFrame")
            return None

        # Warn (but do not fail) if no indicator features were produced.
        feature_cols = [col for col in df.columns if col.startswith(('entity_', 'relation_'))]
        if not feature_cols:
            logger.warning("No feature columns found in DataFrame")
        else:
            logger.debug("Found %d feature columns", len(feature_cols))

        return df

    except Exception as e:
        logger.error(f"Error creating component influence DataFrame: {str(e)}")
        return None

def create_component_influence_dataframe_from_file(input_path: str) -> Optional[pd.DataFrame]:
    """
    Create a DataFrame for component influence analysis from a JSON file.

    Legacy function maintained for backward compatibility: it merely loads
    the JSON payload and delegates to the pure builder function.

    Args:
        input_path: Path to the JSON file containing analysis data

    Returns:
        pandas.DataFrame with component features and perturbation scores,
        or None if creation fails
    """
    try:
        with open(input_path, 'r') as handle:
            payload = json.load(handle)

        # Pull out the three data sections, tolerating missing keys.
        graph = payload.get('knowledge_graph', {})
        return create_component_influence_dataframe(
            payload.get('perturbation_tests', []),
            payload.get('prompt_reconstructions', []),
            graph.get('relations', []),
        )

    except Exception as e:
        logger.error(f"Error creating DataFrame from file {input_path}: {str(e)}")
        return None

def main():
    """
    Command-line entry point for testing the DataFrame builder.

    Returns 0 on success, 1 if the DataFrame could not be created.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Test component influence DataFrame creation')
    parser.add_argument('--input', type=str, required=True, help='Path to input JSON file with analysis data')
    parser.add_argument('--output', type=str, help='Path to output CSV file (optional)')
    opts = parser.parse_args()

    frame = create_component_influence_dataframe_from_file(opts.input)
    if frame is None:
        print("ERROR: Failed to create DataFrame")
        return 1

    print(f"Successfully created DataFrame with {len(frame)} rows and {len(frame.columns)} columns")
    print(f"Columns: {list(frame.columns)}")
    print("Perturbation score stats:")

    # Summary statistics of the target column, printed one per line.
    scores = frame['perturbation']
    for label, value in (
        ("Mean", scores.mean()),
        ("Std", scores.std()),
        ("Min", scores.min()),
        ("Max", scores.max()),
    ):
        print(f"  {label}: {value:.4f}")

    # Persist to CSV only when an output path was supplied.
    if opts.output:
        frame.to_csv(opts.output, index=False)
        print(f"DataFrame saved to {opts.output}")

    return 0

if __name__ == "__main__":
    main()