File size: 4,972 Bytes
c2ea5ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
Shared Utility Functions for Causal Analysis

This module contains utility functions that are used across multiple
causal analysis methods to avoid code duplication.
"""

import pandas as pd
import numpy as np
from typing import Dict, List, Any, Union
import logging

logger = logging.getLogger(__name__)

def create_mock_perturbation_scores(
    num_components: int = 10,
    num_tests: int = 50,
    score_range: tuple = (0.1, 0.9),
    seed: int = 42
) -> pd.DataFrame:
    """
    Create mock perturbation scores for testing causal analysis methods.
    
    Args:
        num_components: Number of components to generate
        num_tests: Number of perturbation tests per component
        score_range: Range of scores (min, max)
        seed: Random seed for reproducibility
        
    Returns:
        DataFrame with one row per (component, test) pair and columns
        'component', 'test_id', 'perturbation_score', 'relation_id',
        'perturbation_type'.
    """
    # Use a dedicated RandomState rather than np.random.seed(), which would
    # clobber NumPy's *global* RNG state for unrelated code.  RandomState(seed)
    # draws the exact same sequence as the legacy global API, so the generated
    # data is identical to the previous implementation.
    rng = np.random.RandomState(seed)
    
    data = []
    for comp_id in range(num_components):
        component_name = f"component_{comp_id:03d}"
        for test_id in range(num_tests):
            score = rng.uniform(score_range[0], score_range[1])
            # Add some realistic patterns
            if comp_id < 3:  # Make first few components more influential
                score *= 1.2
            if test_id % 10 == 0:  # Add some noise
                score *= rng.uniform(0.8, 1.2)
                
            data.append({
                'component': component_name,
                'test_id': test_id,
                # Clamp at 1.0: the 1.2x boosts above can push past the
                # nominal score_range maximum.
                'perturbation_score': min(1.0, score),
                'relation_id': f"rel_{comp_id}_{test_id}",
                'perturbation_type': rng.choice(['jailbreak', 'counterfactual_bias'])
            })
    
    return pd.DataFrame(data)

def list_available_components(df: pd.DataFrame) -> List[str]:
    """
    Extract the list of available components from a perturbation DataFrame.
    
    Prefers an explicit 'component' column; otherwise derives component names
    from the first two underscore-separated parts of each 'relation_id'.
    
    Args:
        df: DataFrame containing perturbation data
        
    Returns:
        Sorted list of unique component names (empty if neither column exists)
    """
    if 'component' in df.columns:
        return sorted(df['component'].unique().tolist())
    if 'relation_id' in df.columns:
        # Any string containing '_' splits into at least two parts, so the
        # first two parts are always available here.
        derived = {
            '_'.join(rid.split('_')[:2])
            for rid in df['relation_id'].unique()
            if isinstance(rid, str) and '_' in rid
        }
        return sorted(derived)
    logger.warning("DataFrame does not contain 'component' or 'relation_id' columns")
    return []

def validate_analysis_data(analysis_data: Dict[str, Any]) -> bool:
    """
    Validate that analysis data contains required fields for causal analysis.
    
    Checks that all required keys are present and that the perturbation tests
    and scores collections are non-empty.
    
    Args:
        analysis_data: Dictionary containing analysis data
        
    Returns:
        True if data is valid, False otherwise
    """
    required_fields = ('perturbation_tests', 'knowledge_graph', 'perturbation_scores')
    
    # Report only the first missing field, then bail out.
    missing = [name for name in required_fields if name not in analysis_data]
    if missing:
        logger.error(f"Missing required field: {missing[0]}")
        return False
    
    if not analysis_data['perturbation_tests']:
        logger.error("No perturbation tests found in analysis data")
        return False
    if not analysis_data['perturbation_scores']:
        logger.error("No perturbation scores found in analysis data")
        return False
    
    return True

def extract_component_scores(analysis_data: Dict[str, Any]) -> Dict[str, float]:
    """
    Extract component scores from analysis data in a standardized format.
    
    Non-numeric and NaN entries in 'perturbation_scores' are skipped.
    
    Args:
        analysis_data: Dictionary containing analysis data
        
    Returns:
        Dictionary mapping component names to their scores (empty dict when
        validation fails)
    """
    if not validate_analysis_data(analysis_data):
        return {}
    
    # Keep only finite numeric entries, coerced to plain floats.
    return {
        relation_id: float(value)
        for relation_id, value in analysis_data['perturbation_scores'].items()
        if isinstance(value, (int, float)) and not np.isnan(value)
    }

def calculate_component_statistics(scores: Dict[str, float]) -> Dict[str, float]:
    """
    Calculate statistical measures for component scores.
    
    Args:
        scores: Dictionary of component scores
        
    Returns:
        Dictionary with 'mean', 'median', 'std' (population), 'min', 'max'
        and 'count'; empty dict for empty input
    """
    if not scores:
        return {}
    
    # Convert once to an ndarray so all reductions run on the same buffer.
    sample = np.asarray(list(scores.values()), dtype=float)
    
    return {
        'mean': sample.mean(),
        'median': np.median(sample),
        'std': sample.std(),
        'min': sample.min(),
        'max': sample.max(),
        'count': len(sample)
    }