# Source: TimeFlowPro/correlations/correlation_analyzer.py
# Uploaded by ArabovMK — commit "Update all files" (d8f69a9)
# ============================================
# CLASS 8: CORRELATION AND MULTICOLLINEARITY ANALYSIS
# ============================================
import logging
import os
import traceback
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd

from venv import logger  # noqa: F401 -- accidental auto-import, superseded below
from config.config import Config

# FIX: `from venv import logger` routed all records through the stdlib `venv`
# package's logger. Rebind to a properly named module-level logger.
logger = logging.getLogger(__name__)
class CorrelationAnalyzer:
"""Class for comprehensive correlation and multicollinearity analysis"""
    def __init__(self, config: Config):
        """
        Initialise the analyser

        Parameters:
        -----------
        config : Config
            Experiment configuration
        """
        self.config = config
        # Correlation matrices cached from the last analyze() run, keyed by
        # method name (currently only 'pearson' is produced).
        self.correlation_matrices = {}
        # High-correlation pair lists, keyed the same way ('pearson').
        self.high_correlation_pairs = {}
        # Extra findings such as correlation clusters from the detailed pass.
        self.multicollinearity_info = {}
        # Mapping of feature name -> VIF value from the last VIF computation.
        self.vif_scores = {}
def analyze(
self,
data: pd.DataFrame,
target_col: Optional[str] = None,
threshold: float = 0.8,
detailed: bool = True,
**kwargs
) -> pd.DataFrame:
"""
Analyse correlations in the data
Parameters:
-----------
data : pd.DataFrame
Input data
target_col : str, optional
Target variable
threshold : float
Threshold for identifying high correlations
detailed : bool
Whether to perform detailed analysis
**kwargs : dict
Additional parameters
Returns:
--------
pd.DataFrame
Correlation matrix
"""
logger.info("\n" + "="*80)
logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS")
logger.info("="*80)
target_col = target_col or self.config.target_column
try:
# 1. Calculate correlation matrix
corr_matrix = self._compute_correlations(data, target_col)
if corr_matrix.empty:
logger.warning("Correlation matrix is empty")
return pd.DataFrame()
# 2. Identify high correlations
high_correlations = self._detect_high_correlations(corr_matrix, threshold)
self.high_correlation_pairs['pearson'] = high_correlations
# 3. Analyse correlations with target variable
target_correlations = []
if target_col in corr_matrix.columns:
target_correlations = self._get_target_correlations(corr_matrix, target_col)
# 4. Analyse multicollinearity (VIF)
vif_results = self._compute_vif_scores(data)
# 5. Detailed analysis if required
if detailed:
self._detailed_correlation_analysis(data, corr_matrix, target_col)
# 6. Visualisation
if self.config.save_plots:
self._plot_correlation_analysis(data, corr_matrix, target_col, high_correlations, vif_results)
# 7. Output results
self._log_analysis_results(corr_matrix, high_correlations, target_correlations, vif_results)
return corr_matrix
except Exception as e:
logger.error(f"Error in correlation analysis: {e}")
logger.error(traceback.format_exc())
return pd.DataFrame()
def _compute_correlations(
self,
data: pd.DataFrame,
target_col: str
) -> pd.DataFrame:
"""Calculate correlation matrix"""
logger.info("Calculating correlation matrix...")
# Select only numeric columns
numeric_data = data.select_dtypes(include=[np.number])
# Remove constant columns
numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1]
if numeric_data.shape[1] < 2:
logger.warning("Insufficient numeric features for analysis")
return pd.DataFrame()
# Remove missing values
numeric_data_clean = numeric_data.dropna()
if len(numeric_data_clean) < 10:
logger.warning("Insufficient data after cleaning")
return pd.DataFrame()
# Calculate Pearson correlation
try:
corr_matrix = numeric_data_clean.corr(method='pearson')
self.correlation_matrices['pearson'] = corr_matrix
logger.info(f"✓ Correlation matrix calculated: {corr_matrix.shape}")
return corr_matrix
except Exception as e:
logger.error(f"Error calculating correlation: {e}")
return pd.DataFrame()
def _detect_high_correlations(
self,
corr_matrix: pd.DataFrame,
threshold: float = 0.8
) -> List[Dict[str, Any]]:
"""Detect high correlations"""
high_correlations = []
if corr_matrix.empty:
return high_correlations
# Use upper triangle of matrix
upper_triangle = corr_matrix.where(
np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
)
# Find pairs with correlation above threshold
for col in upper_triangle.columns:
if col in upper_triangle:
high_corr_series = upper_triangle[col][abs(upper_triangle[col]) > threshold]
for row_idx, correlation in high_corr_series.items():
if not pd.isna(correlation):
high_correlations.append({
'feature1': row_idx,
'feature2': col,
'correlation': float(correlation),
'abs_correlation': abs(float(correlation))
})
# Sort by absolute correlation value
high_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True)
logger.info(f"High correlations detected (> {threshold}): {len(high_correlations)}")
return high_correlations
def _get_target_correlations(
self,
corr_matrix: pd.DataFrame,
target_col: str
) -> List[Dict[str, Any]]:
"""Get correlations with target variable"""
target_correlations = []
if target_col not in corr_matrix.columns:
return target_correlations
# Extract correlations with target variable
target_corr_series = corr_matrix[target_col]
for feature, correlation in target_corr_series.items():
if feature != target_col and not pd.isna(correlation):
target_correlations.append({
'feature': feature,
'correlation': float(correlation),
'abs_correlation': abs(float(correlation)),
'direction': 'positive' if correlation > 0 else 'negative'
})
# Sort by absolute value
target_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True)
logger.info(f"Correlations with target variable calculated: {len(target_correlations)}")
return target_correlations
def _compute_vif_scores(self, data: pd.DataFrame) -> Dict[str, Any]:
"""Calculate VIF (Variance Inflation Factor)"""
logger.info("Analysing multicollinearity (VIF)...")
vif_results = {
'scores': {},
'issues': [],
'summary': {
'critical': 0,
'high': 0,
'medium': 0,
'low': 0
}
}
try:
from statsmodels.stats.outliers_influence import variance_inflation_factor
import statsmodels.api as sm
# Prepare data
numeric_data = data.select_dtypes(include=[np.number])
numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1]
# Remove missing and infinite values
clean_data = numeric_data.replace([np.inf, -np.inf], np.nan).dropna()
if clean_data.shape[0] < 10 or clean_data.shape[1] < 2:
logger.warning("Insufficient data for VIF analysis")
return vif_results
# Add constant
X = sm.add_constant(clean_data, has_constant='add')
# Calculate VIF for each feature
vif_scores = {}
for i, column in enumerate(X.columns):
if column == 'const':
continue
try:
vif = variance_inflation_factor(X.values, i)
# Handle extreme values
if np.isinf(vif) or vif > 1e6:
vif = 1e6
vif_scores[column] = float(vif)
# Classify by severity
if vif > 100:
vif_results['summary']['critical'] += 1
vif_results['issues'].append({
'feature': column,
'vif': float(vif),
'severity': 'critical',
'recommendation': 'Remove feature'
})
elif vif > 10:
vif_results['summary']['high'] += 1
vif_results['issues'].append({
'feature': column,
'vif': float(vif),
'severity': 'high',
'recommendation': 'Consider removal'
})
elif vif > 5:
vif_results['summary']['medium'] += 1
else:
vif_results['summary']['low'] += 1
except Exception as e:
logger.warning(f"VIF error for {column}: {e}")
vif_scores[column] = np.nan
vif_results['scores'] = vif_scores
self.vif_scores = vif_scores
logger.info(f"✓ VIF analysis completed. Critical features: {vif_results['summary']['critical']}")
except ImportError:
logger.warning("statsmodels not installed, skipping VIF analysis")
except Exception as e:
logger.error(f"VIF analysis error: {e}")
return vif_results
def _detailed_correlation_analysis(
self,
data: pd.DataFrame,
corr_matrix: pd.DataFrame,
target_col: str
) -> None:
"""Detailed correlation analysis"""
# Analyse correlation clusters
if not corr_matrix.empty and corr_matrix.shape[0] > 3:
try:
# Use clustering to group correlated features
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import squareform
# Convert correlations to distances
distance_matrix = 1 - abs(corr_matrix)
np.fill_diagonal(distance_matrix.values, 0)
# Clustering
condensed_dist = squareform(distance_matrix)
Z = linkage(condensed_dist, method='average')
# Determine clusters
clusters = fcluster(Z, t=0.5, criterion='distance')
# Group features by cluster
feature_clusters = {}
for idx, cluster_id in enumerate(clusters):
feature = corr_matrix.columns[idx]
if cluster_id not in feature_clusters:
feature_clusters[cluster_id] = []
feature_clusters[cluster_id].append(feature)
# Save cluster information
self.multicollinearity_info['correlation_clusters'] = feature_clusters
logger.info(f"Correlated feature clusters detected: {len(feature_clusters)}")
except Exception as e:
logger.debug(f"Cluster analysis failed: {e}")
def _plot_correlation_analysis(
self,
data: pd.DataFrame,
corr_matrix: pd.DataFrame,
target_col: str,
high_correlations: List[Dict[str, Any]],
vif_results: Dict[str, Any]
) -> None:
"""Visualise correlation analysis"""
try:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
# Style settings
plt.style.use('seaborn-v0_8-darkgrid')
rcParams.update({
'figure.figsize': (12, 8),
'font.size': 10,
'axes.titlesize': 14,
'axes.labelsize': 12
})
# Create directory
plots_dir = os.path.join(self.config.results_dir, 'plots', 'correlations')
os.makedirs(plots_dir, exist_ok=True)
# 1. Correlation matrix heatmap
if not corr_matrix.empty and corr_matrix.shape[0] > 1:
fig, ax = plt.subplots(figsize=(14, 12))
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
sns.heatmap(
corr_matrix,
mask=mask,
annot=True,
fmt='.2f',
cmap='coolwarm',
center=0,
square=True,
linewidths=0.5,
cbar_kws={"shrink": 0.8},
ax=ax
)
ax.set_title('Correlation Matrix (Pearson)', fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(plots_dir, 'correlation_matrix.png'),
dpi=150, bbox_inches='tight')
plt.close()
# 2. Target variable correlations
if target_col in corr_matrix.columns:
target_corrs = corr_matrix[target_col].drop(target_col, errors='ignore')
if not target_corrs.empty:
fig, ax = plt.subplots(figsize=(10, 8))
top_corrs = target_corrs.abs().sort_values(ascending=True).tail(20)
colors = ['red' if target_corrs[feat] < 0 else 'blue'
for feat in top_corrs.index]
ax.barh(range(len(top_corrs)), top_corrs.values, color=colors)
ax.set_yticks(range(len(top_corrs)))
ax.set_yticklabels(top_corrs.index)
ax.set_xlabel('Absolute correlation')
ax.set_title(f'Top-20 correlations with {target_col}', fontweight='bold')
ax.grid(True, alpha=0.3, axis='x')
plt.tight_layout()
plt.savefig(os.path.join(plots_dir, 'target_correlations.png'),
dpi=150, bbox_inches='tight')
plt.close()
# 3. VIF scores plot
if vif_results['scores']:
valid_scores = {k: v for k, v in vif_results['scores'].items()
if not pd.isna(v)}
if valid_scores:
fig, ax = plt.subplots(figsize=(12, 8))
sorted_scores = dict(sorted(valid_scores.items(),
key=lambda x: x[1],
reverse=True)[:25])
colors = []
for vif in sorted_scores.values():
if vif > 100:
colors.append('red')
elif vif > 10:
colors.append('orange')
elif vif > 5:
colors.append('yellow')
else:
colors.append('green')
bars = ax.barh(list(sorted_scores.keys()),
list(sorted_scores.values()),
color=colors, edgecolor='black')
ax.set_xlabel('VIF Score')
ax.set_title('VIF Scores (multicollinearity)', fontweight='bold')
ax.axvline(x=5, color='yellow', linestyle='--', alpha=0.7)
ax.axvline(x=10, color='orange', linestyle='--', alpha=0.7)
ax.axvline(x=100, color='red', linestyle='--', alpha=0.7)
ax.grid(True, alpha=0.3, axis='x')
plt.tight_layout()
plt.savefig(os.path.join(plots_dir, 'vif_scores.png'),
dpi=150, bbox_inches='tight')
plt.close()
# 4. High correlations plot
if high_correlations:
fig, ax = plt.subplots(figsize=(12, 8))
# Limit number for display
display_corrs = high_correlations[:15]
# Create labels for feature pairs
labels = [f"{corr['feature1']}{corr['feature2']}"
for corr in display_corrs]
values = [corr['correlation'] for corr in display_corrs]
colors = ['red' if v < 0 else 'blue' for v in values]
y_pos = np.arange(len(display_corrs))
ax.barh(y_pos, values, color=colors)
ax.set_yticks(y_pos)
ax.set_yticklabels(labels, fontsize=9)
ax.invert_yaxis()
ax.set_xlabel('Correlation')
ax.set_title('High correlations (> 0.8)', fontweight='bold')
ax.grid(True, alpha=0.3, axis='x')
plt.tight_layout()
plt.savefig(os.path.join(plots_dir, 'high_correlations.png'),
dpi=150, bbox_inches='tight')
plt.close()
logger.info(f"Visualisations saved to {plots_dir}")
except Exception as e:
logger.warning(f"Error creating visualisations: {e}")
def _log_analysis_results(
self,
corr_matrix: pd.DataFrame,
high_correlations: List[Dict[str, Any]],
target_correlations: List[Dict[str, Any]],
vif_results: Dict[str, Any]
) -> None:
"""Log analysis results"""
logger.info("\n" + "="*80)
logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS REPORT")
logger.info("="*80)
# General information
logger.info(f"\n📊 GENERAL INFORMATION:")
logger.info(f" Correlation matrix size: {corr_matrix.shape}")
logger.info(f" Total features: {len(corr_matrix.columns)}")
# High correlations
if high_correlations:
logger.info(f"\n⚠ HIGH CORRELATIONS (|r| > 0.8): {len(high_correlations)}")
logger.info(" " + "-" * 60)
for i, corr in enumerate(high_correlations[:10]):
sign = "🟥" if corr['correlation'] < 0 else "🟩"
logger.info(f" {i+1:2d}. {sign} {corr['feature1']:25s}{corr['feature2']:25s}: {corr['correlation']:7.4f}")
if len(high_correlations) > 10:
logger.info(f" ... and {len(high_correlations) - 10} more pairs")
# Target variable correlations
if target_correlations:
logger.info(f"\n🎯 CORRELATIONS WITH TARGET VARIABLE:")
logger.info(" " + "-" * 60)
for i, corr in enumerate(target_correlations[:10]):
direction = "↓" if corr['correlation'] < 0 else "↑"
logger.info(f" {i+1:2d}. {direction} {corr['feature']:35s}: {corr['correlation']:7.4f}")
# Multicollinearity analysis
if vif_results['scores']:
logger.info(f"\n📈 MULTICOLLINEARITY ANALYSIS (VIF):")
logger.info(" " + "-" * 60)
logger.info(f" Critical (VIF > 100): {vif_results['summary']['critical']}")
logger.info(f" High (10 < VIF ≤ 100): {vif_results['summary']['high']}")
logger.info(f" Medium (5 < VIF ≤ 10): {vif_results['summary']['medium']}")
logger.info(f" Low (VIF ≤ 5): {vif_results['summary']['low']}")
# Top problematic features
if vif_results['issues']:
logger.info(f"\n🔴 PROBLEMATIC FEATURES (VIF > 10):")
for issue in vif_results['issues'][:10]:
logger.info(f" • {issue['feature']:35s}: VIF = {issue['vif']:7.1f} ({issue['severity']})")
logger.info("\n" + "="*80)
logger.info("RECOMMENDATIONS:")
logger.info("="*80)
# Generate recommendations
recommendations = []
if len(high_correlations) > 20:
recommendations.append("1. Remove highly correlated features (correlation method)")
if vif_results['summary']['critical'] > 0:
recommendations.append("2. Remove features with critical VIF (>100)")
if vif_results['summary']['high'] > 5:
recommendations.append("3. Consider removing features with VIF > 10")
if not recommendations:
recommendations.append("1. Data in good condition, no serious issues detected")
recommendations.append("2. Proceed to modelling")
for i, rec in enumerate(recommendations, 1):
logger.info(f" {rec}")
logger.info("\n" + "="*80)
def remove_highly_correlated(
self,
data: pd.DataFrame,
threshold: float = 0.85,
method: str = 'variance',
keep_target: bool = True,
keep_features: List[str] = None
) -> pd.DataFrame:
"""
Remove highly correlated features
Parameters:
-----------
data : pd.DataFrame
Source data
threshold : float
Correlation threshold for removal
method : str
Feature selection method for removal: 'variance', 'random', 'importance'
keep_target : bool
Whether to keep target variable
keep_features : List[str], optional
Features to keep
Returns:
--------
pd.DataFrame
Data after removing highly correlated features
"""
logger.info("\n" + "="*80)
logger.info("REMOVING HIGHLY CORRELATED FEATURES")
logger.info("="*80)
data_clean = data.copy()
if 'pearson' not in self.correlation_matrices:
logger.warning("Correlation matrix not calculated, run analyze() first")
return data_clean
corr_matrix = self.correlation_matrices['pearson']
# Features to keep
features_to_keep = set()
if keep_target and self.config.target_column in data_clean.columns:
features_to_keep.add(self.config.target_column)
if keep_features:
for feat in keep_features:
if feat in data_clean.columns:
features_to_keep.add(feat)
# Temporal features (usually important for time series)
temporal_patterns = ['year', 'month', 'day', 'week', 'quarter',
'hour', 'minute', 'second', 'sin', 'cos']
for col in data_clean.columns:
if any(pattern in col.lower() for pattern in temporal_patterns):
features_to_keep.add(col)
# Find highly correlated pairs
upper_triangle = corr_matrix.where(
np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
)
# Collect highly correlated features
correlated_features = set()
for col in upper_triangle.columns:
if col in features_to_keep:
continue
high_corr = upper_triangle[col][abs(upper_triangle[col]) > threshold]
for row_idx, corr_value in high_corr.items():
if not pd.isna(corr_value) and row_idx not in features_to_keep:
# Select which feature to remove
if method == 'variance':
# Remove the one with lower variance
var_col = data_clean[col].var()
var_row = data_clean[row_idx].var()
feature_to_remove = col if var_col < var_row else row_idx
elif method == 'importance':
# Remove the one with lower correlation to target variable
if self.config.target_column in corr_matrix.columns:
corr_col_target = abs(corr_matrix.loc[col, self.config.target_column])
corr_row_target = abs(corr_matrix.loc[row_idx, self.config.target_column])
feature_to_remove = col if corr_col_target < corr_row_target else row_idx
else:
# If no target, remove randomly
feature_to_remove = np.random.choice([col, row_idx])
else:
# Remove randomly
feature_to_remove = np.random.choice([col, row_idx])
correlated_features.add(feature_to_remove)
# Remove features
features_to_remove = list(correlated_features)
if features_to_remove:
data_clean = data_clean.drop(columns=features_to_remove)
logger.info(f"\n📊 REMOVAL RESULTS:")
logger.info(f" Initial feature count: {len(data.columns)}")
logger.info(f" Features removed: {len(features_to_remove)}")
logger.info(f" Final feature count: {len(data_clean.columns)}")
logger.info(f" Retained: {len(data_clean.columns)/len(data.columns)*100:.1f}%")
if features_to_remove:
logger.info(f"\n🗑️ REMOVED FEATURES:")
for i, feat in enumerate(sorted(features_to_remove)[:20]):
logger.info(f" {i+1:2d}. {feat}")
if len(features_to_remove) > 20:
logger.info(f" ... and {len(features_to_remove) - 20} more features")
else:
logger.info("✓ No highly correlated features detected, all features retained")
logger.info("="*80)
return data_clean
def get_report(self) -> Dict[str, Any]:
"""Get analysis report"""
report = {
"correlation_matrix_shape": None,
"high_correlation_count": 0,
"vif_summary": {},
"target_correlation_count": 0
}
if 'pearson' in self.correlation_matrices:
report["correlation_matrix_shape"] = self.correlation_matrices['pearson'].shape
if 'pearson' in self.high_correlation_pairs:
report["high_correlation_count"] = len(self.high_correlation_pairs['pearson'])
if self.vif_scores:
report["vif_summary"] = self.vif_scores.get('summary', {})
return report