""" Shared Utility Functions for Causal Analysis This module contains utility functions that are used across multiple causal analysis methods to avoid code duplication. """ import pandas as pd import numpy as np from typing import Dict, List, Any, Union import logging logger = logging.getLogger(__name__) def create_mock_perturbation_scores( num_components: int = 10, num_tests: int = 50, score_range: tuple = (0.1, 0.9), seed: int = 42 ) -> pd.DataFrame: """ Create mock perturbation scores for testing causal analysis methods. Args: num_components: Number of components to generate num_tests: Number of perturbation tests per component score_range: Range of scores (min, max) seed: Random seed for reproducibility Returns: DataFrame with component perturbation scores """ np.random.seed(seed) data = [] for comp_id in range(num_components): component_name = f"component_{comp_id:03d}" for test_id in range(num_tests): score = np.random.uniform(score_range[0], score_range[1]) # Add some realistic patterns if comp_id < 3: # Make first few components more influential score *= 1.2 if test_id % 10 == 0: # Add some noise score *= np.random.uniform(0.8, 1.2) data.append({ 'component': component_name, 'test_id': test_id, 'perturbation_score': min(1.0, score), 'relation_id': f"rel_{comp_id}_{test_id}", 'perturbation_type': np.random.choice(['jailbreak', 'counterfactual_bias']) }) return pd.DataFrame(data) def list_available_components(df: pd.DataFrame) -> List[str]: """ Extract the list of available components from a perturbation DataFrame. Args: df: DataFrame containing perturbation data Returns: List of unique component names """ if 'component' in df.columns: return sorted(df['component'].unique().tolist()) elif 'relation_id' in df.columns: # Extract component names from relation IDs if component column doesn't exist components = [] for rel_id in df['relation_id'].unique(): if isinstance(rel_id, str) and '_' in rel_id: # Assume format like "component_001_test_id" or "rel_comp_id" parts = rel_id.split('_') if len(parts) >= 2: component = f"{parts[0]}_{parts[1]}" components.append(component) return sorted(list(set(components))) else: logger.warning("DataFrame does not contain 'component' or 'relation_id' columns") return [] def validate_analysis_data(analysis_data: Dict[str, Any]) -> bool: """ Validate that analysis data contains required fields for causal analysis. Args: analysis_data: Dictionary containing analysis data Returns: True if data is valid, False otherwise """ required_fields = ['perturbation_tests', 'knowledge_graph', 'perturbation_scores'] for field in required_fields: if field not in analysis_data: logger.error(f"Missing required field: {field}") return False if not analysis_data['perturbation_tests']: logger.error("No perturbation tests found in analysis data") return False if not analysis_data['perturbation_scores']: logger.error("No perturbation scores found in analysis data") return False return True def extract_component_scores(analysis_data: Dict[str, Any]) -> Dict[str, float]: """ Extract component scores from analysis data in a standardized format. 
    Args:
        analysis_data: Dictionary containing analysis data

    Returns:
        Dictionary mapping relation IDs (the keys of 'perturbation_scores')
        to their scores
    """
    if not validate_analysis_data(analysis_data):
        return {}

    component_scores = {}
    # Keep only finite numeric scores from perturbation_scores.
    for relation_id, score in analysis_data['perturbation_scores'].items():
        if isinstance(score, (int, float)) and not np.isnan(score):
            component_scores[relation_id] = float(score)

    return component_scores


def calculate_component_statistics(scores: Dict[str, float]) -> Dict[str, float]:
    """
    Calculate statistical measures for component scores.

    Args:
        scores: Dictionary of component scores

    Returns:
        Dictionary with statistical measures (mean, median, std, min, max, count)
    """
    if not scores:
        return {}

    values = list(scores.values())
    return {
        'mean': float(np.mean(values)),
        'median': float(np.median(values)),
        'std': float(np.std(values)),
        'min': float(np.min(values)),
        'max': float(np.max(values)),
        'count': len(values),
    }
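

# A minimal usage sketch tying the utilities together: generate mock scores,
# list the components, then extract and summarize scores. The `analysis_data`
# layout below (populated from the mock DataFrame, with an empty placeholder
# knowledge graph) is an assumption for illustration; real callers supply
# their own perturbation tests and knowledge graph.
if __name__ == "__main__":
    df = create_mock_perturbation_scores(num_components=5, num_tests=20)
    print(f"Components: {list_available_components(df)}")

    analysis_data = {
        'perturbation_tests': ['jailbreak', 'counterfactual_bias'],
        'knowledge_graph': {},  # placeholder; contents are not validated
        'perturbation_scores': dict(
            zip(df['relation_id'], df['perturbation_score'])
        ),
    }
    scores = extract_component_scores(analysis_data)
    print(f"Extracted {len(scores)} scores")
    print(calculate_component_statistics(scores))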