Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Component Influence Analysis | |
| This script analyzes the influence of knowledge graph components on perturbation scores | |
| using the DataFrame created by the create_component_influence_dataframe function. | |
| """ | |
| import os | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.ensemble import RandomForestRegressor | |
| from sklearn.metrics import mean_squared_error, r2_score | |
| import logging | |
| from typing import Optional, Dict, List, Tuple, Any | |
| import sys | |
| from sklearn.linear_model import LinearRegression | |
| # Import from the same directory | |
| from .utils.dataframe_builder import create_component_influence_dataframe | |
| # Configure logging for this module | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
def analyze_component_influence(df: pd.DataFrame, n_estimators: int = 100,
                                random_state: int = 42) -> Tuple[Optional[RandomForestRegressor], Dict[str, float], List[str]]:
    """
    Analyzes the influence of components on perturbation scores.

    Uses a linear model to directly estimate the effect size and direction.
    Random Forest is still trained as a secondary model for comparison.

    Args:
        df: DataFrame with binary component features and a 'perturbation' column
        n_estimators: Number of trees in the Random Forest
        random_state: Random seed for reproducibility

    Returns:
        A tuple containing:
        - The trained RandomForestRegressor model (or None if training fails)
        - Dictionary of feature importances with sign (direction), sorted by magnitude
        - List of feature columns used for training
    """
    # Feature columns are all columns starting with "entity_" or "relation_".
    # FIX: removed the former second pass re-checking `col in df.columns` —
    # the list already comes from df.columns, so that filter was a no-op.
    feature_cols = [col for col in df.columns if col.startswith(("entity_", "relation_"))]
    if not feature_cols:
        logger.error("No component features found in DataFrame. Column names should start with 'entity_' or 'relation_'.")
        return None, {}, []
    logger.info(f"Found {len(feature_cols)} feature columns for analysis")

    # Need at least two rows for any regression to be meaningful.
    if len(df) < 2:
        logger.error("Not enough data points for analysis (need at least 2 rows).")
        return None, {}, []

    X = df[feature_cols]
    y = df['perturbation']

    # A constant target yields zero importance for every feature; short-circuit.
    if y.std() == 0:
        logger.warning("Target variable 'perturbation' has no variance. Feature importance will be 0 for all features.")
        return None, {feature: 0.0 for feature in feature_cols}, feature_cols

    try:
        # 1. Train the Random Forest (used for evaluation metrics / comparison).
        rf_model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state)
        rf_model.fit(X, y)

        # 2. Fit a linear model; its coefficients carry both magnitude and sign.
        linear_model = LinearRegression()
        linear_model.fit(X, y)

        # 3. Use the linear coefficients directly as importance scores.
        # FIX: build via zip instead of an index loop, and cast the numpy
        # scalars to plain float to honour the declared Dict[str, float].
        feature_importance = {
            feature: float(coef)
            for feature, coef in zip(feature_cols, linear_model.coef_)
        }
        # Sort by absolute importance (magnitude), largest first.
        feature_importance = dict(sorted(feature_importance.items(), key=lambda item: abs(item[1]), reverse=True))
        return rf_model, feature_importance, feature_cols
    except Exception as e:
        logger.error(f"Error during model training: {e}")
        return None, {feature: 0.0 for feature in feature_cols}, feature_cols
def print_feature_importance(feature_importance: Dict[str, float], top_n: int = 10) -> None:
    """
    Prints the feature importance values with signs (positive/negative influence).

    Args:
        feature_importance: Dictionary mapping feature names to importance values
        top_n: Number of top features to show (the CSV always gets every feature)

    Side effects:
        Writes 'component_influence_rankings.csv' next to this module.
    """
    print(f"\nTop {min(top_n, len(feature_importance))} Components by Influence:")
    print("=" * 50)
    print(f"{'Rank':<5}{'Component':<30}{'Importance':<15}{'Direction':<10}")
    print("-" * 50)

    # Rank by magnitude; the sign is reported separately as a direction.
    sorted_features = sorted(feature_importance.items(), key=lambda x: abs(x[1]), reverse=True)

    # FIX: a plain [:top_n] slice already clamps to the list length, so the
    # former min(top_n, len(feature_importance)) bound was redundant.
    for i, (feature, importance) in enumerate(sorted_features[:top_n], 1):
        direction = "Positive" if importance >= 0 else "Negative"
        print(f"{i:<5}{feature:<30}{abs(importance):.6f} {direction}")

    # Save the complete ranking to CSV for further analysis.
    output_path = os.path.join(os.path.dirname(__file__), 'component_influence_rankings.csv')
    pd.DataFrame({
        'Component': [item[0] for item in sorted_features],
        'Importance': [abs(item[1]) for item in sorted_features],
        'Direction': ["Positive" if item[1] >= 0 else "Negative" for item in sorted_features]
    }).to_csv(output_path, index=False)
    logger.info(f"Component rankings saved to {output_path}")
def evaluate_model(model: Optional[RandomForestRegressor], X: pd.DataFrame, y: pd.Series) -> Dict[str, float]:
    """
    Evaluates regression quality of a fitted model.

    Args:
        model: Trained RandomForestRegressor (or None when training failed)
        X: Feature DataFrame the model predicts from
        y: Ground-truth target series

    Returns:
        Dictionary with 'mse', 'rmse' and 'r2'. When model is None the metrics
        default to zero, except r2 which is 1.0 for a zero-variance target
        (a constant prediction explains a constant target perfectly).
    """
    if model is None:
        constant_target = y.std() == 0
        return {
            'mse': 0.0,
            'rmse': 0.0,
            'r2': 1.0 if constant_target else 0.0,
        }

    try:
        predictions = model.predict(X)
        mse_value = mean_squared_error(y, predictions)
        return {
            'mse': mse_value,
            'rmse': np.sqrt(mse_value),
            'r2': r2_score(y, predictions),
        }
    except Exception as e:
        logger.error(f"Error during model evaluation: {e}")
        return {'mse': 0.0, 'rmse': 0.0, 'r2': 0.0}
def identify_key_components(feature_importance: Dict[str, float],
                            threshold: float = 0.01) -> List[str]:
    """
    Identifies key components whose influence magnitude reaches the threshold.

    Args:
        feature_importance: Dictionary mapping feature names to signed importance values
        threshold: Minimum absolute importance value to be considered a key component

    Returns:
        List of key component names, in the dictionary's iteration order
    """
    key_components = []
    for name, value in feature_importance.items():
        if abs(value) >= threshold:
            key_components.append(name)
    return key_components
def print_component_groups(df: pd.DataFrame, feature_importance: Dict[str, float]) -> None:
    """
    Prints aggregate influence statistics for entity vs. relation components.

    Args:
        df: Original DataFrame (kept for interface compatibility; not read here)
        feature_importance: Feature importance dictionary with signed values
    """
    if not feature_importance:
        print("\nNo feature importance values available for group analysis.")
        return

    # Split features by component-type prefix.
    entity_features = [f for f in feature_importance.keys() if f.startswith('entity_')]
    relation_features = [f for f in feature_importance.keys() if f.startswith('relation_')]

    # Group magnitudes: absolute values so opposing signs do not cancel out.
    entity_importance = sum(abs(feature_importance[f]) for f in entity_features)
    relation_importance = sum(abs(feature_importance[f]) for f in relation_features)
    total_importance = sum(abs(value) for value in feature_importance.values())

    # Direction counts per group (exact zeros count in neither column).
    pos_entities = sum(1 for f in entity_features if feature_importance[f] > 0)
    neg_entities = sum(1 for f in entity_features if feature_importance[f] < 0)
    pos_relations = sum(1 for f in relation_features if feature_importance[f] > 0)
    neg_relations = sum(1 for f in relation_features if feature_importance[f] < 0)

    print("\nComponent Group Influence:")
    print("=" * 70)
    print(f"{'Group':<20}{'Abs Importance':<15}{'Percentage':<10}{'Positive':<10}{'Negative':<10}")
    print("-" * 70)
    if total_importance > 0:
        # FIX: the inner `if total_importance > 0 else 0` ternaries were
        # redundant inside this branch and have been removed.
        entity_percentage = entity_importance / total_importance * 100
        relation_percentage = relation_importance / total_importance * 100
        print(f"{'Entities':<20}{entity_importance:.6f}{'%.2f%%' % entity_percentage:<10}{pos_entities:<10}{neg_entities:<10}")
        print(f"{'Relations':<20}{relation_importance:.6f}{'%.2f%%' % relation_percentage:<10}{pos_relations:<10}{neg_relations:<10}")
    else:
        print("No importance values available for analysis.")
def main():
    """Command-line entry point for the component influence analysis.

    Parses --input/--output, builds the influence DataFrame from the knowledge
    graph JSON, prints summary statistics, trains the models, and reports
    importances, key components, group influence and model metrics.
    """
    # Local import keeps argparse out of module scope when imported as a library.
    import argparse
    parser = argparse.ArgumentParser(description='Analyze component influence on perturbation scores')
    parser.add_argument('--input', '-i', required=True, help='Path to the knowledge graph JSON file')
    parser.add_argument('--output', '-o', help='Path to save the output DataFrame (CSV format)')
    args = parser.parse_args()

    print("\n=== Component Influence Analysis ===")
    print(f"Input file: {args.input}")
    print(f"Output file: {args.output or 'Not specified'}")

    # Build the analysis DataFrame from the knowledge graph JSON.
    print("\nCreating DataFrame from knowledge graph...")
    df = create_component_influence_dataframe(args.input)
    if df is None or df.empty:
        logger.error("Failed to create or empty DataFrame. Cannot proceed with analysis.")
        return

    # Summarize the DataFrame shape and feature composition.
    print(f"\nDataFrame info:")
    print(f"Rows: {len(df)}")
    entity_features = [col for col in df.columns if col.startswith("entity_")]
    relation_features = [col for col in df.columns if col.startswith("relation_")]
    print(f"Entity features: {len(entity_features)}")
    print(f"Relation features: {len(relation_features)}")
    print(f"Other columns: {', '.join([col for col in df.columns if not (col.startswith('entity_') or col.startswith('relation_'))])}")

    # A constant target makes every importance zero, so warn up front.
    if df['perturbation'].std() == 0:
        logger.warning("All perturbation scores are identical. This might lead to uninformative results.")
        print("\nWARNING: All perturbation scores are identical (value: %.2f). Results may not be meaningful." % df['perturbation'].iloc[0])
    else:
        print(f"\nPerturbation score distribution:")
        print(f"Min: {df['perturbation'].min():.2f}, Max: {df['perturbation'].max():.2f}")
        print(f"Mean: {df['perturbation'].mean():.2f}, Std: {df['perturbation'].std():.2f}")

    # Train the models and derive signed importances.
    print("\nRunning component influence analysis...")
    model, feature_importance, feature_cols = analyze_component_influence(df)

    # Report the ranked importances (also writes the rankings CSV).
    print_feature_importance(feature_importance)

    # Filter to components whose |importance| clears the default 0.01 threshold.
    print("\nIdentifying key components...")
    key_components = identify_key_components(feature_importance)
    print(f"Identified {len(key_components)} key components (importance >= 0.01)")

    # Aggregate influence by entity vs. relation group.
    print("\nAnalyzing component groups...")
    print_component_groups(df, feature_importance)

    # Evaluate the Random Forest on the training data (in-sample metrics).
    print("\nEvaluating model performance...")
    metrics = evaluate_model(model, df[feature_cols], df['perturbation'])
    print("\nModel Evaluation Metrics:")
    print("=" * 50)
    for metric, value in metrics.items():
        print(f"{metric.upper()}: {value:.6f}")

    # Optionally persist the DataFrame with one constant importance column
    # per feature for downstream reference.
    if args.output:
        result_df = df.copy()
        for feature, importance in feature_importance.items():
            result_df[f'importance_{feature}'] = importance
        result_df.to_csv(args.output)
        logger.info(f"Full analysis results saved to {args.output}")

    print("\nAnalysis complete. CSV files with detailed results have been saved.")


if __name__ == "__main__":
    main()