# ============================================ # CLASS 8: CORRELATION AND MULTICOLLINEARITY ANALYSIS # ============================================ import os import traceback from typing import Any, Dict, List, Optional from venv import logger from config.config import Config import numpy as np import pandas as pd class CorrelationAnalyzer: """Class for comprehensive correlation and multicollinearity analysis""" def __init__(self, config: Config): """ Initialise the analyser Parameters: ----------- config : Config Experiment configuration """ self.config = config self.correlation_matrices = {} self.high_correlation_pairs = {} self.multicollinearity_info = {} self.vif_scores = {} def analyze( self, data: pd.DataFrame, target_col: Optional[str] = None, threshold: float = 0.8, detailed: bool = True, **kwargs ) -> pd.DataFrame: """ Analyse correlations in the data Parameters: ----------- data : pd.DataFrame Input data target_col : str, optional Target variable threshold : float Threshold for identifying high correlations detailed : bool Whether to perform detailed analysis **kwargs : dict Additional parameters Returns: -------- pd.DataFrame Correlation matrix """ logger.info("\n" + "="*80) logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS") logger.info("="*80) target_col = target_col or self.config.target_column try: # 1. Calculate correlation matrix corr_matrix = self._compute_correlations(data, target_col) if corr_matrix.empty: logger.warning("Correlation matrix is empty") return pd.DataFrame() # 2. Identify high correlations high_correlations = self._detect_high_correlations(corr_matrix, threshold) self.high_correlation_pairs['pearson'] = high_correlations # 3. Analyse correlations with target variable target_correlations = [] if target_col in corr_matrix.columns: target_correlations = self._get_target_correlations(corr_matrix, target_col) # 4. Analyse multicollinearity (VIF) vif_results = self._compute_vif_scores(data) # 5. Detailed analysis if required if detailed: self._detailed_correlation_analysis(data, corr_matrix, target_col) # 6. Visualisation if self.config.save_plots: self._plot_correlation_analysis(data, corr_matrix, target_col, high_correlations, vif_results) # 7. Output results self._log_analysis_results(corr_matrix, high_correlations, target_correlations, vif_results) return corr_matrix except Exception as e: logger.error(f"Error in correlation analysis: {e}") logger.error(traceback.format_exc()) return pd.DataFrame() def _compute_correlations( self, data: pd.DataFrame, target_col: str ) -> pd.DataFrame: """Calculate correlation matrix""" logger.info("Calculating correlation matrix...") # Select only numeric columns numeric_data = data.select_dtypes(include=[np.number]) # Remove constant columns numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1] if numeric_data.shape[1] < 2: logger.warning("Insufficient numeric features for analysis") return pd.DataFrame() # Remove missing values numeric_data_clean = numeric_data.dropna() if len(numeric_data_clean) < 10: logger.warning("Insufficient data after cleaning") return pd.DataFrame() # Calculate Pearson correlation try: corr_matrix = numeric_data_clean.corr(method='pearson') self.correlation_matrices['pearson'] = corr_matrix logger.info(f"āœ“ Correlation matrix calculated: {corr_matrix.shape}") return corr_matrix except Exception as e: logger.error(f"Error calculating correlation: {e}") return pd.DataFrame() def _detect_high_correlations( self, corr_matrix: pd.DataFrame, threshold: float = 0.8 ) -> List[Dict[str, Any]]: """Detect high correlations""" high_correlations = [] if corr_matrix.empty: return high_correlations # Use upper triangle of matrix upper_triangle = corr_matrix.where( np.triu(np.ones(corr_matrix.shape), k=1).astype(bool) ) # Find pairs with correlation above threshold for col in upper_triangle.columns: if col in upper_triangle: high_corr_series = upper_triangle[col][abs(upper_triangle[col]) > threshold] for row_idx, correlation in high_corr_series.items(): if not pd.isna(correlation): high_correlations.append({ 'feature1': row_idx, 'feature2': col, 'correlation': float(correlation), 'abs_correlation': abs(float(correlation)) }) # Sort by absolute correlation value high_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True) logger.info(f"High correlations detected (> {threshold}): {len(high_correlations)}") return high_correlations def _get_target_correlations( self, corr_matrix: pd.DataFrame, target_col: str ) -> List[Dict[str, Any]]: """Get correlations with target variable""" target_correlations = [] if target_col not in corr_matrix.columns: return target_correlations # Extract correlations with target variable target_corr_series = corr_matrix[target_col] for feature, correlation in target_corr_series.items(): if feature != target_col and not pd.isna(correlation): target_correlations.append({ 'feature': feature, 'correlation': float(correlation), 'abs_correlation': abs(float(correlation)), 'direction': 'positive' if correlation > 0 else 'negative' }) # Sort by absolute value target_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True) logger.info(f"Correlations with target variable calculated: {len(target_correlations)}") return target_correlations def _compute_vif_scores(self, data: pd.DataFrame) -> Dict[str, Any]: """Calculate VIF (Variance Inflation Factor)""" logger.info("Analysing multicollinearity (VIF)...") vif_results = { 'scores': {}, 'issues': [], 'summary': { 'critical': 0, 'high': 0, 'medium': 0, 'low': 0 } } try: from statsmodels.stats.outliers_influence import variance_inflation_factor import statsmodels.api as sm # Prepare data numeric_data = data.select_dtypes(include=[np.number]) numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1] # Remove missing and infinite values clean_data = numeric_data.replace([np.inf, -np.inf], np.nan).dropna() if clean_data.shape[0] < 10 or clean_data.shape[1] < 2: logger.warning("Insufficient data for VIF analysis") return vif_results # Add constant X = sm.add_constant(clean_data, has_constant='add') # Calculate VIF for each feature vif_scores = {} for i, column in enumerate(X.columns): if column == 'const': continue try: vif = variance_inflation_factor(X.values, i) # Handle extreme values if np.isinf(vif) or vif > 1e6: vif = 1e6 vif_scores[column] = float(vif) # Classify by severity if vif > 100: vif_results['summary']['critical'] += 1 vif_results['issues'].append({ 'feature': column, 'vif': float(vif), 'severity': 'critical', 'recommendation': 'Remove feature' }) elif vif > 10: vif_results['summary']['high'] += 1 vif_results['issues'].append({ 'feature': column, 'vif': float(vif), 'severity': 'high', 'recommendation': 'Consider removal' }) elif vif > 5: vif_results['summary']['medium'] += 1 else: vif_results['summary']['low'] += 1 except Exception as e: logger.warning(f"VIF error for {column}: {e}") vif_scores[column] = np.nan vif_results['scores'] = vif_scores self.vif_scores = vif_scores logger.info(f"āœ“ VIF analysis completed. Critical features: {vif_results['summary']['critical']}") except ImportError: logger.warning("statsmodels not installed, skipping VIF analysis") except Exception as e: logger.error(f"VIF analysis error: {e}") return vif_results def _detailed_correlation_analysis( self, data: pd.DataFrame, corr_matrix: pd.DataFrame, target_col: str ) -> None: """Detailed correlation analysis""" # Analyse correlation clusters if not corr_matrix.empty and corr_matrix.shape[0] > 3: try: # Use clustering to group correlated features from scipy.cluster.hierarchy import linkage, dendrogram, fcluster from scipy.spatial.distance import squareform # Convert correlations to distances distance_matrix = 1 - abs(corr_matrix) np.fill_diagonal(distance_matrix.values, 0) # Clustering condensed_dist = squareform(distance_matrix) Z = linkage(condensed_dist, method='average') # Determine clusters clusters = fcluster(Z, t=0.5, criterion='distance') # Group features by cluster feature_clusters = {} for idx, cluster_id in enumerate(clusters): feature = corr_matrix.columns[idx] if cluster_id not in feature_clusters: feature_clusters[cluster_id] = [] feature_clusters[cluster_id].append(feature) # Save cluster information self.multicollinearity_info['correlation_clusters'] = feature_clusters logger.info(f"Correlated feature clusters detected: {len(feature_clusters)}") except Exception as e: logger.debug(f"Cluster analysis failed: {e}") def _plot_correlation_analysis( self, data: pd.DataFrame, corr_matrix: pd.DataFrame, target_col: str, high_correlations: List[Dict[str, Any]], vif_results: Dict[str, Any] ) -> None: """Visualise correlation analysis""" try: import matplotlib.pyplot as plt import seaborn as sns from matplotlib import rcParams # Style settings plt.style.use('seaborn-v0_8-darkgrid') rcParams.update({ 'figure.figsize': (12, 8), 'font.size': 10, 'axes.titlesize': 14, 'axes.labelsize': 12 }) # Create directory plots_dir = os.path.join(self.config.results_dir, 'plots', 'correlations') os.makedirs(plots_dir, exist_ok=True) # 1. Correlation matrix heatmap if not corr_matrix.empty and corr_matrix.shape[0] > 1: fig, ax = plt.subplots(figsize=(14, 12)) mask = np.triu(np.ones_like(corr_matrix, dtype=bool)) sns.heatmap( corr_matrix, mask=mask, annot=True, fmt='.2f', cmap='coolwarm', center=0, square=True, linewidths=0.5, cbar_kws={"shrink": 0.8}, ax=ax ) ax.set_title('Correlation Matrix (Pearson)', fontweight='bold') plt.tight_layout() plt.savefig(os.path.join(plots_dir, 'correlation_matrix.png'), dpi=150, bbox_inches='tight') plt.close() # 2. Target variable correlations if target_col in corr_matrix.columns: target_corrs = corr_matrix[target_col].drop(target_col, errors='ignore') if not target_corrs.empty: fig, ax = plt.subplots(figsize=(10, 8)) top_corrs = target_corrs.abs().sort_values(ascending=True).tail(20) colors = ['red' if target_corrs[feat] < 0 else 'blue' for feat in top_corrs.index] ax.barh(range(len(top_corrs)), top_corrs.values, color=colors) ax.set_yticks(range(len(top_corrs))) ax.set_yticklabels(top_corrs.index) ax.set_xlabel('Absolute correlation') ax.set_title(f'Top-20 correlations with {target_col}', fontweight='bold') ax.grid(True, alpha=0.3, axis='x') plt.tight_layout() plt.savefig(os.path.join(plots_dir, 'target_correlations.png'), dpi=150, bbox_inches='tight') plt.close() # 3. VIF scores plot if vif_results['scores']: valid_scores = {k: v for k, v in vif_results['scores'].items() if not pd.isna(v)} if valid_scores: fig, ax = plt.subplots(figsize=(12, 8)) sorted_scores = dict(sorted(valid_scores.items(), key=lambda x: x[1], reverse=True)[:25]) colors = [] for vif in sorted_scores.values(): if vif > 100: colors.append('red') elif vif > 10: colors.append('orange') elif vif > 5: colors.append('yellow') else: colors.append('green') bars = ax.barh(list(sorted_scores.keys()), list(sorted_scores.values()), color=colors, edgecolor='black') ax.set_xlabel('VIF Score') ax.set_title('VIF Scores (multicollinearity)', fontweight='bold') ax.axvline(x=5, color='yellow', linestyle='--', alpha=0.7) ax.axvline(x=10, color='orange', linestyle='--', alpha=0.7) ax.axvline(x=100, color='red', linestyle='--', alpha=0.7) ax.grid(True, alpha=0.3, axis='x') plt.tight_layout() plt.savefig(os.path.join(plots_dir, 'vif_scores.png'), dpi=150, bbox_inches='tight') plt.close() # 4. High correlations plot if high_correlations: fig, ax = plt.subplots(figsize=(12, 8)) # Limit number for display display_corrs = high_correlations[:15] # Create labels for feature pairs labels = [f"{corr['feature1']} ↔ {corr['feature2']}" for corr in display_corrs] values = [corr['correlation'] for corr in display_corrs] colors = ['red' if v < 0 else 'blue' for v in values] y_pos = np.arange(len(display_corrs)) ax.barh(y_pos, values, color=colors) ax.set_yticks(y_pos) ax.set_yticklabels(labels, fontsize=9) ax.invert_yaxis() ax.set_xlabel('Correlation') ax.set_title('High correlations (> 0.8)', fontweight='bold') ax.grid(True, alpha=0.3, axis='x') plt.tight_layout() plt.savefig(os.path.join(plots_dir, 'high_correlations.png'), dpi=150, bbox_inches='tight') plt.close() logger.info(f"Visualisations saved to {plots_dir}") except Exception as e: logger.warning(f"Error creating visualisations: {e}") def _log_analysis_results( self, corr_matrix: pd.DataFrame, high_correlations: List[Dict[str, Any]], target_correlations: List[Dict[str, Any]], vif_results: Dict[str, Any] ) -> None: """Log analysis results""" logger.info("\n" + "="*80) logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS REPORT") logger.info("="*80) # General information logger.info(f"\nšŸ“Š GENERAL INFORMATION:") logger.info(f" Correlation matrix size: {corr_matrix.shape}") logger.info(f" Total features: {len(corr_matrix.columns)}") # High correlations if high_correlations: logger.info(f"\n⚠ HIGH CORRELATIONS (|r| > 0.8): {len(high_correlations)}") logger.info(" " + "-" * 60) for i, corr in enumerate(high_correlations[:10]): sign = "🟄" if corr['correlation'] < 0 else "🟩" logger.info(f" {i+1:2d}. {sign} {corr['feature1']:25s} ↔ {corr['feature2']:25s}: {corr['correlation']:7.4f}") if len(high_correlations) > 10: logger.info(f" ... and {len(high_correlations) - 10} more pairs") # Target variable correlations if target_correlations: logger.info(f"\nšŸŽÆ CORRELATIONS WITH TARGET VARIABLE:") logger.info(" " + "-" * 60) for i, corr in enumerate(target_correlations[:10]): direction = "↓" if corr['correlation'] < 0 else "↑" logger.info(f" {i+1:2d}. {direction} {corr['feature']:35s}: {corr['correlation']:7.4f}") # Multicollinearity analysis if vif_results['scores']: logger.info(f"\nšŸ“ˆ MULTICOLLINEARITY ANALYSIS (VIF):") logger.info(" " + "-" * 60) logger.info(f" Critical (VIF > 100): {vif_results['summary']['critical']}") logger.info(f" High (10 < VIF ≤ 100): {vif_results['summary']['high']}") logger.info(f" Medium (5 < VIF ≤ 10): {vif_results['summary']['medium']}") logger.info(f" Low (VIF ≤ 5): {vif_results['summary']['low']}") # Top problematic features if vif_results['issues']: logger.info(f"\nšŸ”“ PROBLEMATIC FEATURES (VIF > 10):") for issue in vif_results['issues'][:10]: logger.info(f" • {issue['feature']:35s}: VIF = {issue['vif']:7.1f} ({issue['severity']})") logger.info("\n" + "="*80) logger.info("RECOMMENDATIONS:") logger.info("="*80) # Generate recommendations recommendations = [] if len(high_correlations) > 20: recommendations.append("1. Remove highly correlated features (correlation method)") if vif_results['summary']['critical'] > 0: recommendations.append("2. Remove features with critical VIF (>100)") if vif_results['summary']['high'] > 5: recommendations.append("3. Consider removing features with VIF > 10") if not recommendations: recommendations.append("1. Data in good condition, no serious issues detected") recommendations.append("2. Proceed to modelling") for i, rec in enumerate(recommendations, 1): logger.info(f" {rec}") logger.info("\n" + "="*80) def remove_highly_correlated( self, data: pd.DataFrame, threshold: float = 0.85, method: str = 'variance', keep_target: bool = True, keep_features: List[str] = None ) -> pd.DataFrame: """ Remove highly correlated features Parameters: ----------- data : pd.DataFrame Source data threshold : float Correlation threshold for removal method : str Feature selection method for removal: 'variance', 'random', 'importance' keep_target : bool Whether to keep target variable keep_features : List[str], optional Features to keep Returns: -------- pd.DataFrame Data after removing highly correlated features """ logger.info("\n" + "="*80) logger.info("REMOVING HIGHLY CORRELATED FEATURES") logger.info("="*80) data_clean = data.copy() if 'pearson' not in self.correlation_matrices: logger.warning("Correlation matrix not calculated, run analyze() first") return data_clean corr_matrix = self.correlation_matrices['pearson'] # Features to keep features_to_keep = set() if keep_target and self.config.target_column in data_clean.columns: features_to_keep.add(self.config.target_column) if keep_features: for feat in keep_features: if feat in data_clean.columns: features_to_keep.add(feat) # Temporal features (usually important for time series) temporal_patterns = ['year', 'month', 'day', 'week', 'quarter', 'hour', 'minute', 'second', 'sin', 'cos'] for col in data_clean.columns: if any(pattern in col.lower() for pattern in temporal_patterns): features_to_keep.add(col) # Find highly correlated pairs upper_triangle = corr_matrix.where( np.triu(np.ones(corr_matrix.shape), k=1).astype(bool) ) # Collect highly correlated features correlated_features = set() for col in upper_triangle.columns: if col in features_to_keep: continue high_corr = upper_triangle[col][abs(upper_triangle[col]) > threshold] for row_idx, corr_value in high_corr.items(): if not pd.isna(corr_value) and row_idx not in features_to_keep: # Select which feature to remove if method == 'variance': # Remove the one with lower variance var_col = data_clean[col].var() var_row = data_clean[row_idx].var() feature_to_remove = col if var_col < var_row else row_idx elif method == 'importance': # Remove the one with lower correlation to target variable if self.config.target_column in corr_matrix.columns: corr_col_target = abs(corr_matrix.loc[col, self.config.target_column]) corr_row_target = abs(corr_matrix.loc[row_idx, self.config.target_column]) feature_to_remove = col if corr_col_target < corr_row_target else row_idx else: # If no target, remove randomly feature_to_remove = np.random.choice([col, row_idx]) else: # Remove randomly feature_to_remove = np.random.choice([col, row_idx]) correlated_features.add(feature_to_remove) # Remove features features_to_remove = list(correlated_features) if features_to_remove: data_clean = data_clean.drop(columns=features_to_remove) logger.info(f"\nšŸ“Š REMOVAL RESULTS:") logger.info(f" Initial feature count: {len(data.columns)}") logger.info(f" Features removed: {len(features_to_remove)}") logger.info(f" Final feature count: {len(data_clean.columns)}") logger.info(f" Retained: {len(data_clean.columns)/len(data.columns)*100:.1f}%") if features_to_remove: logger.info(f"\nšŸ—‘ļø REMOVED FEATURES:") for i, feat in enumerate(sorted(features_to_remove)[:20]): logger.info(f" {i+1:2d}. {feat}") if len(features_to_remove) > 20: logger.info(f" ... and {len(features_to_remove) - 20} more features") else: logger.info("āœ“ No highly correlated features detected, all features retained") logger.info("="*80) return data_clean def get_report(self) -> Dict[str, Any]: """Get analysis report""" report = { "correlation_matrix_shape": None, "high_correlation_count": 0, "vif_summary": {}, "target_correlation_count": 0 } if 'pearson' in self.correlation_matrices: report["correlation_matrix_shape"] = self.correlation_matrices['pearson'].shape if 'pearson' in self.high_correlation_pairs: report["high_correlation_count"] = len(self.high_correlation_pairs['pearson']) if self.vif_scores: report["vif_summary"] = self.vif_scores.get('summary', {}) return report