# ============================================ # CLASS 13: VISUALISATION MANAGER (UPDATED) # ============================================ import os from datetime import datetime import json from typing import Dict, List, Optional, Tuple, Union, Any import pandas as pd import numpy as np from statsmodels.graphics.tsaplots import plot_acf, plot_pacf import matplotlib.pyplot as plt import seaborn as sns from scipy.stats import gaussian_kde import matplotlib matplotlib.use('Agg') # Use non-display backend from config.config import Config import logging # Logging setup logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class VisualisationManager: """Class for managing all visualisations""" def __init__(self, config: Config): """ Initialise visualisation manager Parameters: ----------- config : Config Experiment configuration """ self.config = config self.plots_generated = {} self.plot_files = {} self.figure_count = 0 # Create directory structure for saving plots self._create_directory_structure() def _create_directory_structure(self) -> None: """Create directory structure for saving plots""" base_dir = self.config.results_dir # Main plot directories self.plots_dir = os.path.join(base_dir, "plots") self.correlations_dir = os.path.join(base_dir, "plots", "correlations") self.distributions_dir = os.path.join(base_dir, "plots", "distributions") self.features_dir = os.path.join(base_dir, "plots", "features") self.time_series_dir = os.path.join(base_dir, "plots", "time_series") self.preprocessing_dir = os.path.join(base_dir, "plots", "preprocessing") self.summary_dir = os.path.join(base_dir, "plots", "summary") self.reports_dir = os.path.join(base_dir, "reports") # Create directories directories = [ self.plots_dir, self.correlations_dir, self.distributions_dir, self.features_dir, self.time_series_dir, self.preprocessing_dir, self.summary_dir, self.reports_dir ] for directory in directories: os.makedirs(directory, exist_ok=True) logger.debug(f"Created directory: {directory}") def _save_figure(self, fig: plt.Figure, filename: str, subdirectory: str = None, dpi: int = 300) -> str: """ Save plot and close it Parameters: ----------- fig : matplotlib.figure.Figure Plot figure object filename : str Filename for saving subdirectory : str, optional Subdirectory for saving dpi : int Save quality Returns: -------- str : full path to saved file """ if not filename.endswith('.png'): filename = f"{filename}.png" if subdirectory: save_dir = os.path.join(self.plots_dir, subdirectory) os.makedirs(save_dir, exist_ok=True) else: save_dir = self.plots_dir filepath = os.path.join(save_dir, filename) try: fig.savefig(filepath, dpi=dpi, bbox_inches='tight', facecolor='white') logger.info(f"✓ Plot saved: {filepath}") except Exception as e: logger.error(f"✗ Error saving plot {filename}: {e}") filepath = None # Close plot without display plt.close(fig) return filepath # ============================================ # MAIN VISUALISATION METHODS # ============================================ def create_summary_dashboard( self, data: pd.DataFrame, preprocessing_stages: Dict = None, filename: str = "summary_dashboard" ) -> str: """ Create summary visualisation dashboard Parameters: ----------- data : pd.DataFrame Data for visualisation preprocessing_stages : Dict, optional Preprocessing stages information filename : str Filename for saving Returns: -------- str : path to saved file or None if error """ logger.info("\n" + "="*80) logger.info("CREATING SUMMARY DASHBOARD") logger.info("="*80) target_col = self.config.target_column try: # Create large dashboard fig = plt.figure(figsize=(20, 24)) gs = fig.add_gridspec(6, 4, hspace=0.3, wspace=0.3) # 1. Time series of target variable ax1 = fig.add_subplot(gs[0, :2]) if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex): ax1.plot(data.index, data[target_col], linewidth=1, color='blue', alpha=0.7) ax1.set_title(f'Time Series: {target_col}', fontsize=12, fontweight='bold') ax1.set_xlabel('Date', fontsize=10) ax1.set_ylabel(target_col, fontsize=10) ax1.grid(True, alpha=0.3) ax1.tick_params(axis='x', rotation=45) else: ax1.text(0.5, 0.5, 'No time series data available', ha='center', va='center', transform=ax1.transAxes) # 2. Target variable distribution ax2 = fig.add_subplot(gs[0, 2:]) if target_col in data.columns: values = data[target_col].dropna() if len(values) > 0: ax2.hist(values, bins=30, edgecolor='black', alpha=0.7, color='green') ax2.set_title(f'Distribution: {target_col}', fontsize=12, fontweight='bold') ax2.set_xlabel(target_col, fontsize=10) ax2.set_ylabel('Frequency', fontsize=10) ax2.grid(True, alpha=0.3) else: ax2.text(0.5, 0.5, 'No data for distribution', ha='center', va='center', transform=ax2.transAxes) # 3. Correlation matrix (top features) ax3 = fig.add_subplot(gs[1, :]) numeric_cols = data.select_dtypes(include=[np.number]).columns if len(numeric_cols) > 1: display_cols = list(numeric_cols[:15]) if target_col not in display_cols and target_col in data.columns: display_cols = [target_col] + [c for c in display_cols if c != target_col][:14] corr_matrix = data[display_cols].corr() mask = np.triu(np.ones_like(corr_matrix, dtype=bool)) im = ax3.imshow(corr_matrix.where(~mask), cmap='coolwarm', vmin=-1, vmax=1, aspect='auto') ax3.set_title('Correlation Matrix (Top 15 Features)', fontsize=12, fontweight='bold') ax3.set_xticks(range(len(display_cols))) ax3.set_yticks(range(len(display_cols))) ax3.set_xticklabels(display_cols, rotation=90, fontsize=8) ax3.set_yticklabels(display_cols, fontsize=8) plt.colorbar(im, ax=ax3, shrink=0.8) # 4. Seasonal patterns ax4 = fig.add_subplot(gs[2, :2]) if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex): data_copy = data.copy() data_copy['month'] = data_copy.index.month monthly_avg = data_copy.groupby('month')[target_col].mean() colors = plt.cm.Set3(np.linspace(0, 1, len(monthly_avg))) ax4.bar(monthly_avg.index, monthly_avg.values, color=colors, edgecolor='black') ax4.set_title('Average Values by Month', fontsize=12, fontweight='bold') ax4.set_xlabel('Month', fontsize=10) ax4.set_ylabel(f'Average {target_col}', fontsize=10) ax4.set_xticks(range(1, 13)) month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] ax4.set_xticklabels(month_names) ax4.grid(True, alpha=0.3, axis='y') # 5. Weekly patterns ax5 = fig.add_subplot(gs[2, 2:]) if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex): data_copy = data.copy() data_copy['dayofweek'] = data_copy.index.dayofweek daily_avg = data_copy.groupby('dayofweek')[target_col].mean() colors = plt.cm.Paired(np.linspace(0, 1, len(daily_avg))) ax5.bar(daily_avg.index, daily_avg.values, color=colors, edgecolor='black') ax5.set_title('Average Values by Day of Week', fontsize=12, fontweight='bold') ax5.set_xlabel('Day of Week', fontsize=10) ax5.set_ylabel(f'Average {target_col}', fontsize=10) ax5.set_xticks(range(7)) ax5.set_xticklabels(['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']) ax5.grid(True, alpha=0.3, axis='y') # 6. Trend and seasonality ax6 = fig.add_subplot(gs[3, :]) if target_col in data.columns and len(data) > 30: try: window_size = min(365, len(data) // 10) if window_size >= 7: rolling_mean = data[target_col].rolling(window=window_size, center=True).mean() rolling_std = data[target_col].rolling(window=window_size, center=True).std() ax6.plot(data.index, data[target_col], alpha=0.5, label='Original Series', linewidth=0.5, color='blue') ax6.plot(rolling_mean.index, rolling_mean, label=f'Rolling Mean ({window_size} days)', color='red', linewidth=2) ax6.fill_between(rolling_mean.index, rolling_mean - rolling_std, rolling_mean + rolling_std, alpha=0.2, color='red') ax6.set_title('Trend and Volatility', fontsize=12, fontweight='bold') ax6.set_xlabel('Date', fontsize=10) ax6.set_ylabel(target_col, fontsize=10) ax6.legend(fontsize=9, loc='upper left') ax6.grid(True, alpha=0.3) else: ax6.text(0.5, 0.5, 'Insufficient data for trend analysis', ha='center', va='center', transform=ax6.transAxes) except Exception as e: logger.warning(f"Error plotting trend: {e}") ax6.text(0.5, 0.5, 'Error plotting trend', ha='center', va='center', transform=ax6.transAxes) # 7. Preprocessing statistics if preprocessing_stages: ax7 = fig.add_subplot(gs[4, :2]) stages = list(preprocessing_stages.keys()) values = list(preprocessing_stages.values()) colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(stages))) bars = ax7.bar(range(len(stages)), values, color=colors, edgecolor='black') ax7.set_title('Preprocessing Statistics', fontsize=12, fontweight='bold') ax7.set_xlabel('Processing Stage', fontsize=10) ax7.set_ylabel('Value', fontsize=10) ax7.set_xticks(range(len(stages))) ax7.set_xticklabels([s[:15] + '...' if len(s) > 15 else s for s in stages], rotation=45, ha='right', fontsize=9) ax7.grid(True, alpha=0.3, axis='y') # Add values on bars for bar, value in zip(bars, values): height = bar.get_height() ax7.text(bar.get_x() + bar.get_width()/2., height, f'{value:.2f}', ha='center', va='bottom', fontsize=8) # 8. Data information ax8 = fig.add_subplot(gs[4, 2:]) ax8.axis('off') info_text = [] info_text.append("GENERAL CHARACTERISTICS:") info_text.append(f"• Number of records: {len(data):,}") info_text.append(f"• Number of features: {len(data.columns)}") if isinstance(data.index, pd.DatetimeIndex): info_text.append(f"• Period: {data.index.min().strftime('%Y-%m-%d')} - " f"{data.index.max().strftime('%Y-%m-%d')}") info_text.append(f"• Days of data: {(data.index.max() - data.index.min()).days}") if target_col in data.columns: target_stats = data[target_col].describe() info_text.append(f"\nTARGET VARIABLE '{target_col}':") info_text.append(f"• Mean: {target_stats['mean']:.2f}") info_text.append(f"• Standard deviation: {target_stats['std']:.2f}") info_text.append(f"• Minimum: {target_stats['min']:.2f}") info_text.append(f"• 25%: {target_stats['25%']:.2f}") info_text.append(f"• 50% (median): {target_stats['50%']:.2f}") info_text.append(f"• 75%: {target_stats['75%']:.2f}") info_text.append(f"• Maximum: {target_stats['max']:.2f}") info_text.append(f"\nDATA TYPES:") for dtype, count in data.dtypes.value_counts().items(): info_text.append(f"• {dtype}: {count} columns") missing_info = data.isnull().sum() missing_total = missing_info.sum() missing_percent = missing_total / data.size * 100 info_text.append(f"\nMISSING VALUES:") info_text.append(f"• Total missing: {missing_total:,}") info_text.append(f"• Missing percentage: {missing_percent:.2f}%") if missing_total > 0: top_missing = missing_info.nlargest(5) info_text.append(f"• Top 5 columns with missing values:") for col, count in top_missing.items(): percent = count / len(data) * 100 info_text.append(f" {col}: {count} ({percent:.1f}%)") ax8.text(0.02, 0.98, '\n'.join(info_text), transform=ax8.transAxes, fontsize=8, verticalalignment='top', fontfamily='monospace') # 9. Autocorrelation plot ax9 = fig.add_subplot(gs[5, :2]) if target_col in data.columns: try: series = data[target_col].dropna() if len(series) > 50: plot_acf(series, lags=min(50, len(series)-1), ax=ax9, alpha=0.05) ax9.set_title('Autocorrelation Function (ACF)', fontsize=12, fontweight='bold') ax9.set_xlabel('Lag', fontsize=10) ax9.set_ylabel('Autocorrelation', fontsize=10) ax9.grid(True, alpha=0.3) else: ax9.text(0.5, 0.5, 'Insufficient data for ACF', ha='center', va='center', transform=ax9.transAxes) except Exception as e: logger.warning(f"Error plotting ACF: {e}") ax9.text(0.5, 0.5, 'Error calculating ACF', ha='center', va='center', transform=ax9.transAxes) # 10. Partial autocorrelation plot ax10 = fig.add_subplot(gs[5, 2:]) if target_col in data.columns: try: series = data[target_col].dropna() if len(series) > 50: plot_pacf(series, lags=min(50, len(series)-1), ax=ax10, alpha=0.05) ax10.set_title('Partial Autocorrelation Function (PACF)', fontsize=12, fontweight='bold') ax10.set_xlabel('Lag', fontsize=10) ax10.set_ylabel('Partial Autocorrelation', fontsize=10) ax10.grid(True, alpha=0.3) else: ax10.text(0.5, 0.5, 'Insufficient data for PACF', ha='center', va='center', transform=ax10.transAxes) except Exception as e: logger.warning(f"Error plotting PACF: {e}") ax10.text(0.5, 0.5, 'Error calculating PACF', ha='center', va='center', transform=ax10.transAxes) plt.suptitle('Data Analysis Summary Dashboard', fontsize=16, fontweight='bold', y=0.98) plt.tight_layout() # Save filepath = self._save_figure(fig, filename, "summary") self.plot_files['summary_dashboard'] = filepath return filepath except Exception as e: logger.error(f"Error creating summary dashboard: {e}") return None # ============================================ # SPECIFIC METHODS FOR SAVING YOUR PLOTS # ============================================ def save_data_split_plot(self, filename: str = "data_split.png") -> str: """ Save data split plot Parameters: ----------- filename : str Filename for saving Returns: -------- str : path to saved file """ try: fig = plt.gcf() # Get current figure filepath = self._save_figure(fig, filename, "time_series") self.plot_files['data_split'] = filepath return filepath except Exception as e: logger.error(f"Error saving data_split plot: {e}") return None def save_feature_selection_correlation_plot(self, filename: str = "feature_selection_correlation.png") -> str: """ Save feature selection correlation plot Parameters: ----------- filename : str Filename for saving Returns: -------- str : path to saved file """ try: fig = plt.gcf() # Get current figure filepath = self._save_figure(fig, filename, "correlations") self.plot_files['feature_selection_correlation'] = filepath return filepath except Exception as e: logger.error(f"Error saving feature_selection_correlation plot: {e}") return None def save_missing_values_analysis_plot(self, filename: str = "missing_values_analysis.png") -> str: """ Save missing values analysis plot Parameters: ----------- filename : str Filename for saving Returns: -------- str : path to saved file """ try: fig = plt.gcf() # Get current figure filepath = self._save_figure(fig, filename, "preprocessing") self.plot_files['missing_values_analysis'] = filepath return filepath except Exception as e: logger.error(f"Error saving missing_values_analysis plot: {e}") return None def save_outlier_handling_results_plot(self, filename: str = "outlier_handling_results.png") -> str: """ Save outlier handling results plot Parameters: ----------- filename : str Filename for saving Returns: -------- str : path to saved file """ try: fig = plt.gcf() # Get current figure filepath = self._save_figure(fig, filename, "preprocessing") self.plot_files['outlier_handling_results'] = filepath return filepath except Exception as e: logger.error(f"Error saving outlier_handling_results plot: {e}") return None def save_outliers_analysis_plot(self, filename: str = "outliers_analysis.png") -> str: """ Save outliers analysis plot Parameters: ----------- filename : str Filename for saving Returns: -------- str : path to saved file """ try: fig = plt.gcf() # Get current figure filepath = self._save_figure(fig, filename, "preprocessing") self.plot_files['outliers_analysis'] = filepath return filepath except Exception as e: logger.error(f"Error saving outliers_analysis plot: {e}") return None def save_scaling_results_plot(self, filename: str = "scaling_results.png") -> str: """ Save scaling results plot Parameters: ----------- filename : str Filename for saving Returns: -------- str : path to saved file """ try: fig = plt.gcf() # Get current figure filepath = self._save_figure(fig, filename, "preprocessing") self.plot_files['scaling_results'] = filepath return filepath except Exception as e: logger.error(f"Error saving scaling_results plot: {e}") return None def save_stationarity_analysis_plot(self, filename: str = "stationarity_analysis.png") -> str: """ Save stationarity analysis plot Parameters: ----------- filename : str Filename for saving Returns: -------- str : path to saved file """ try: fig = plt.gcf() # Get current figure filepath = self._save_figure(fig, filename, "time_series") self.plot_files['stationarity_analysis'] = filepath return filepath except Exception as e: logger.error(f"Error saving stationarity_analysis plot: {e}") return None def save_temporal_outliers_plot(self, filename: str = "temporal_outliers.png") -> str: """ Save temporal outliers plot Parameters: ----------- filename : str Filename for saving Returns: -------- str : path to saved file """ try: fig = plt.gcf() # Get current figure filepath = self._save_figure(fig, filename, "time_series") self.plot_files['temporal_outliers'] = filepath return filepath except Exception as e: logger.error(f"Error saving temporal_outliers plot: {e}") return None # ============================================ # UNIVERSAL METHOD FOR SAVING ANY PLOT # ============================================ def save_current_plot(self, filename: str, subdirectory: str = None) -> str: """ Universal method for saving current plot Parameters: ----------- filename : str Filename for saving subdirectory : str, optional Subdirectory for saving Returns: -------- str : path to saved file """ try: fig = plt.gcf() # Get current figure filepath = self._save_figure(fig, filename, subdirectory) # Save plot information plot_key = filename.replace('.png', '').replace('.jpg', '') self.plot_files[plot_key] = filepath return filepath except Exception as e: logger.error(f"Error saving plot {filename}: {e}") return None # ============================================ # ADDITIONAL VISUALISATION METHODS # ============================================ def create_feature_importance_plot( self, feature_importance: Dict, top_n: int = 20, filename: str = "feature_importance" ) -> str: """ Create feature importance plot Parameters: ----------- feature_importance : Dict Dictionary with feature importance top_n : int Number of top features to display filename : str Filename for saving Returns: -------- str : path to saved file or None if error """ if not feature_importance: logger.warning("No feature importance data for visualisation") return None try: # Convert to Series and sort importance_series = pd.Series(feature_importance).sort_values(ascending=False) top_features = importance_series.head(top_n) # Create plot fig, ax = plt.subplots(figsize=(12, 8)) y_pos = np.arange(len(top_features)) colors = plt.cm.plasma(np.linspace(0.2, 0.9, len(top_features))) bars = ax.barh(y_pos, top_features.values, color=colors, edgecolor='black') ax.set_yticks(y_pos) ax.set_yticklabels(top_features.index, fontsize=10) ax.invert_yaxis() ax.set_xlabel('Feature Importance', fontsize=11, fontweight='bold') ax.set_title(f'Top-{top_n} Most Important Features', fontsize=14, fontweight='bold') ax.grid(True, alpha=0.3, axis='x') # Add values on bars for i, (bar, value) in enumerate(zip(bars, top_features.values)): width = bar.get_width() ax.text(width * 1.01, bar.get_y() + bar.get_height()/2, f'{value:.4f}', va='center', fontsize=9, fontweight='bold') # Add additional information plt.text(0.02, 0.98, f'Total features: {len(importance_series)}', transform=fig.transFigure, fontsize=9, verticalalignment='top') plt.tight_layout() # Save filepath = self._save_figure(fig, filename, "features") self.plot_files['feature_importance'] = filepath return filepath except Exception as e: logger.error(f"Error creating feature importance plot: {e}") return None def create_correlation_heatmap( self, data: pd.DataFrame, top_n: int = 20, filename: str = "correlation_heatmap" ) -> Tuple[str, Optional[str]]: """ Create correlation heatmap Parameters: ----------- data : pd.DataFrame Data for analysis top_n : int Number of top features to display filename : str Filename for saving Returns: -------- Tuple[str, Optional[str]]: (path to main heatmap, path to target correlation heatmap) """ target_col = self.config.target_column try: numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist() if len(numeric_cols) < 2: logger.warning("Insufficient numeric features for correlation analysis") return None, None # Create two heatmaps # 1. Main correlation heatmap between all features main_filepath = self._create_main_correlation_heatmap(data, numeric_cols, top_n, filename) # 2. Target correlation heatmap target_filepath = None if target_col in data.columns and target_col in numeric_cols: target_filepath = self._create_target_correlation_heatmap(data, target_col, numeric_cols, filename) return main_filepath, target_filepath except Exception as e: logger.error(f"Error creating correlation heatmap: {e}") return None, None def _create_main_correlation_heatmap( self, data: pd.DataFrame, numeric_cols: List[str], top_n: int, filename: str ) -> str: """Create main correlation heatmap""" # Limit number of features for better readability if len(numeric_cols) > top_n: # Select features with highest variance variances = data[numeric_cols].var().sort_values(ascending=False) selected_cols = variances.head(top_n).index.tolist() else: selected_cols = numeric_cols # Calculate correlation corr_matrix = data[selected_cols].corr() fig, ax = plt.subplots(figsize=(14, 12)) # Mask for upper triangle mask = np.triu(np.ones_like(corr_matrix, dtype=bool)) # Create heatmap sns.heatmap( corr_matrix, annot=True, fmt='.2f', cmap='coolwarm', center=0, square=True, mask=mask, cbar_kws={'shrink': 0.8, 'label': 'Correlation Coefficient'}, linewidths=0.5, linecolor='white', ax=ax, annot_kws={'size': 8} ) ax.set_title(f'Correlation Matrix Between Features (Top-{top_n})', fontsize=14, fontweight='bold', pad=20) plt.tight_layout() # Save filepath = self._save_figure(fig, filename, "correlations") self.plot_files['correlation_heatmap_main'] = filepath return filepath def _create_target_correlation_heatmap( self, data: pd.DataFrame, target_col: str, numeric_cols: List[str], filename: str ) -> str: """Create target correlation heatmap""" # Calculate correlations with target variable correlations = data[numeric_cols].corrwith(data[target_col]).sort_values(key=abs, ascending=False) # Exclude target variable itself correlations = correlations[correlations.index != target_col] # Take top 15 features top_features = correlations.head(15) fig, ax = plt.subplots(figsize=(10, 8)) colors = ['red' if x < 0 else 'green' for x in top_features.values] bars = ax.barh(range(len(top_features)), top_features.values, color=colors, edgecolor='black') ax.set_yticks(range(len(top_features))) ax.set_yticklabels(top_features.index, fontsize=10) ax.invert_yaxis() ax.set_xlabel('Correlation Coefficient', fontsize=11, fontweight='bold') ax.set_title(f'Feature Correlations with Target Variable "{target_col}"', fontsize=14, fontweight='bold', pad=20) ax.grid(True, alpha=0.3, axis='x') ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5) # Add values on bars for bar, value in zip(bars, top_features.values): width = bar.get_width() ax.text(width + (0.01 if width >= 0 else -0.04), bar.get_y() + bar.get_height()/2, f'{value:.3f}', va='center', ha='left' if width >= 0 else 'right', fontsize=9, fontweight='bold', color='black') plt.tight_layout() # Save target_filename = f"{filename}_with_target" filepath = self._save_figure(fig, target_filename, "correlations") self.plot_files['correlation_with_target'] = filepath return filepath def create_distribution_comparison( self, original_data: pd.DataFrame, processed_data: pd.DataFrame, columns: List[str] = None, max_columns: int = 12, filename: str = "distribution_comparison" ) -> str: """ Compare distributions before and after processing Parameters: ----------- original_data : pd.DataFrame Original data processed_data : pd.DataFrame Processed data columns : List[str], optional List of columns to compare max_columns : int Maximum number of columns to display filename : str Filename for saving Returns: -------- str : path to saved file or None if error """ try: if columns is None: # Select numeric columns common to both datasets numeric_cols_original = original_data.select_dtypes(include=[np.number]).columns numeric_cols_processed = processed_data.select_dtypes(include=[np.number]).columns common_cols = list(set(numeric_cols_original) & set(numeric_cols_processed)) # Sort by variance in original data variances = original_data[common_cols].var().sort_values(ascending=False) columns = variances.head(max_columns).index.tolist() n_cols = min(4, len(columns)) n_rows = (len(columns) + n_cols - 1) // n_cols fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3.5)) fig.suptitle('Distribution Comparison Before and After Processing', fontsize=16, fontweight='bold', y=0.98) if n_rows == 1 and n_cols == 1: axes = np.array([axes]) axes = axes.flat if hasattr(axes, 'flat') else [axes] for idx, col in enumerate(columns): if idx >= len(axes): break ax = axes[idx] if col in original_data.columns and col in processed_data.columns: original_values = original_data[col].dropna() processed_values = processed_data[col].dropna() if len(original_values) > 0 and len(processed_values) > 0: # Use common bins for comparison all_values = pd.concat([original_values, processed_values]) bins = np.histogram_bin_edges(all_values, bins=30) # Histograms ax.hist(original_values, bins=bins, alpha=0.5, label='Before Processing', density=True, color='blue') ax.hist(processed_values, bins=bins, alpha=0.5, label='After Processing', density=True, color='orange') # Add KDE try: if len(original_values) > 10: kde_original = gaussian_kde(original_values) x_range = np.linspace(original_values.min(), original_values.max(), 100) ax.plot(x_range, kde_original(x_range), 'b-', linewidth=1.5, alpha=0.8) if len(processed_values) > 10: kde_processed = gaussian_kde(processed_values) x_range = np.linspace(processed_values.min(), processed_values.max(), 100) ax.plot(x_range, kde_processed(x_range), 'orange', linewidth=1.5, alpha=0.8) except: pass # Add statistics stats_text = [] if len(original_values) > 0: stats_text.append(f"Before: μ={original_values.mean():.2f}, σ={original_values.std():.2f}") if len(processed_values) > 0: stats_text.append(f"After: μ={processed_values.mean():.2f}, σ={processed_values.std():.2f}") if stats_text: ax.text(0.02, 0.98, '\n'.join(stats_text), transform=ax.transAxes, fontsize=8, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) ax.set_title(f'{col}', fontsize=11, fontweight='bold') ax.set_xlabel('Value', fontsize=9) ax.set_ylabel('Density', fontsize=9) ax.legend(fontsize=8) ax.grid(True, alpha=0.3) else: ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes) else: ax.text(0.5, 0.5, 'Column not found', ha='center', va='center', transform=ax.transAxes) # Hide unused subplots for idx in range(len(columns), len(axes)): axes[idx].set_visible(False) plt.tight_layout() # Save filepath = self._save_figure(fig, filename, "distributions") self.plot_files['distribution_comparison'] = filepath return filepath except Exception as e: logger.error(f"Error creating distribution comparison: {e}") return None def create_time_series_decomposition_plot( self, decomposition_result: Dict, filename: str = "time_series_decomposition" ) -> str: """ Visualise time series decomposition Parameters: ----------- decomposition_result : Dict Decomposition results filename : str Filename for saving Returns: -------- str : path to saved file or None if error """ target_col = self.config.target_column try: fig, axes = plt.subplots(4, 1, figsize=(14, 10)) fig.suptitle(f'Time Series Decomposition: {target_col}', fontsize=16, fontweight='bold', y=0.98) # Original series if 'observed' in decomposition_result: observed = decomposition_result['observed'] axes[0].plot(observed, color='blue', linewidth=1.5) axes[0].set_ylabel('Observed', fontsize=11, fontweight='bold') axes[0].grid(True, alpha=0.3) axes[0].set_title('Original Time Series', fontsize=12) # Trend if 'trend' in decomposition_result and decomposition_result['trend'] is not None: trend = decomposition_result['trend'] axes[1].plot(trend, color='red', linewidth=2) axes[1].set_ylabel('Trend', fontsize=11, fontweight='bold') axes[1].grid(True, alpha=0.3) axes[1].set_title('Trend Component', fontsize=12) # Seasonality if 'seasonal' in decomposition_result and decomposition_result['seasonal'] is not None: seasonal = decomposition_result['seasonal'] axes[2].plot(seasonal, color='green', linewidth=1.5) axes[2].set_ylabel('Seasonal', fontsize=11, fontweight='bold') axes[2].grid(True, alpha=0.3) axes[2].set_title('Seasonal Component', fontsize=12) # Residuals if 'residual' in decomposition_result and decomposition_result['residual'] is not None: residual = decomposition_result['residual'] axes[3].plot(residual, color='purple', linewidth=1, alpha=0.7) axes[3].set_ylabel('Residuals', fontsize=11, fontweight='bold') axes[3].set_xlabel('Date', fontsize=11, fontweight='bold') axes[3].grid(True, alpha=0.3) axes[3].set_title('Residual Component', fontsize=12) # Add residual statistics if len(residual) > 0: stats_text = (f"Mean: {residual.mean():.4f}\n" f"Std: {residual.std():.4f}\n" f"Min: {residual.min():.4f}\n" f"Max: {residual.max():.4f}") axes[3].text(0.02, 0.98, stats_text, transform=axes[3].transAxes, fontsize=8, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) plt.tight_layout() # Save filepath = self._save_figure(fig, filename, "time_series") self.plot_files['time_series_decomposition'] = filepath return filepath except Exception as e: logger.error(f"Error creating time series decomposition: {e}") return None def create_data_quality_report( self, validation_results: Dict, filename: str = "data_quality_report" ) -> str: """ Create visual data quality report Parameters: ----------- validation_results : Dict Validation results filename : str Filename for saving Returns: -------- str : path to saved file or None if error """ try: fig = plt.figure(figsize=(16, 12)) fig.suptitle('Data Quality Report', fontsize=18, fontweight='bold', y=0.98) # Use GridSpec for more complex layout gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3) # 1. Quality radar chart (top left) ax1 = fig.add_subplot(gs[0, 0], projection='polar') categories = ['Size', 'Missing', 'Duplicates', 'Stability', 'Informativeness'] # Extract values from validation results if 'quality_metrics' in validation_results: values = [ validation_results['quality_metrics'].get('size_score', 0.5), validation_results['quality_metrics'].get('missing_score', 0.5), validation_results['quality_metrics'].get('duplicates_score', 0.5), validation_results['quality_metrics'].get('stability_score', 0.5), validation_results['quality_metrics'].get('informativeness_score', 0.5) ] else: values = [0.8, 0.7, 0.9, 0.6, 0.8] N = len(categories) angles = [n / float(N) * 2 * np.pi for n in range(N)] angles += angles[:1] values += values[:1] ax1.plot(angles, values, 'o-', linewidth=2, color='blue') ax1.fill(angles, values, alpha=0.25, color='blue') ax1.set_xticks(angles[:-1]) ax1.set_xticklabels(categories, fontsize=10) ax1.set_ylim(0, 1) ax1.set_title('Data Quality Radar Chart', fontsize=12, fontweight='bold') ax1.grid(True) # 2. Check status (top right) ax2 = fig.add_subplot(gs[0, 1]) basic_checks = validation_results.get('basic_checks', {}) checks_passed = sum(1 for check in basic_checks.values() if check.get('passed', False)) checks_total = len(basic_checks) checks_failed = checks_total - checks_passed if checks_total > 0: colors = ['#4CAF50' if checks_passed > 0 else '#FF6B6B', '#FF6B6B' if checks_failed > 0 else '#4CAF50'] bars = ax2.bar(['Passed', 'Failed'], [checks_passed, checks_failed], color=colors, edgecolor='black') ax2.set_title(f'Basic Checks: {checks_passed}/{checks_total}', fontsize=12, fontweight='bold') ax2.set_ylabel('Number of Checks', fontsize=10) ax2.grid(True, alpha=0.3, axis='y') # Add values on bars for bar, value in zip(bars, [checks_passed, checks_failed]): height = bar.get_height() ax2.text(bar.get_x() + bar.get_width()/2., height, f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold') else: ax2.text(0.5, 0.5, 'No check data available', ha='center', va='center', transform=ax2.transAxes) ax2.set_title('Basic Checks', fontsize=12, fontweight='bold') # 3. Overall score (top right) ax3 = fig.add_subplot(gs[0, 2]) overall_score = validation_results.get('overall_score', 0) status = validation_results.get('status', 'UNKNOWN') # Score pie chart sizes = [overall_score, 100 - overall_score] if overall_score >= 80: colors = ['#4CAF50', '#E0E0E0'] # Green elif overall_score >= 60: colors = ['#FFC107', '#E0E0E0'] # Yellow else: colors = ['#F44336', '#E0E0E0'] # Red wedges, texts, autotexts = ax3.pie(sizes, colors=colors, startangle=90, autopct='%1.1f%%', pctdistance=0.85) # Central text status_colors = {'PASS': '#4CAF50', 'WARNING': '#FFC107', 'FAIL': '#F44336'} status_color = status_colors.get(status, '#757575') ax3.text(0, 0, f'{overall_score}/100\n{status}', ha='center', va='center', fontsize=14, fontweight='bold', color=status_color) ax3.set_title('Overall Quality Score', fontsize=12, fontweight='bold') # 4. Issue distribution by type (left middle) ax4 = fig.add_subplot(gs[1, 0]) issues = validation_results.get('issues', {}) issue_counts = { 'Critical': len(issues.get('critical', [])), 'Warnings': len(issues.get('warning', [])), 'Informational': len(issues.get('info', [])) } if any(issue_counts.values()): colors = ['#F44336', '#FF9800', '#2196F3'] bars = ax4.bar(issue_counts.keys(), issue_counts.values(), color=colors, edgecolor='black') ax4.set_title('Data Issues by Type', fontsize=12, fontweight='bold') ax4.set_ylabel('Number of Issues', fontsize=10) ax4.tick_params(axis='x', rotation=45) ax4.grid(True, alpha=0.3, axis='y') # Add values on bars for bar, value in zip(bars, issue_counts.values()): height = bar.get_height() ax4.text(bar.get_x() + bar.get_width()/2., height, f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold') else: ax4.text(0.5, 0.5, 'No issues detected', ha='center', va='center', transform=ax4.transAxes, fontsize=12) ax4.set_title('Data Issues', fontsize=12, fontweight='bold') # 5. Detailed information (remaining cells) ax5 = fig.add_subplot(gs[1:, 1:]) ax5.axis('off') # Form text report report_text = [] report_text.append("DETAILED REPORT:") report_text.append("=" * 40) # Basic information report_text.append("\nBASIC INFORMATION:") report_text.append(f"• Overall score: {overall_score}/100") report_text.append(f"• Status: {status}") report_text.append(f"• Checks passed: {checks_passed}/{checks_total}") # Check details if basic_checks: report_text.append("\nCHECK DETAILS:") for check_name, check_result in basic_checks.items(): status_icon = "✓" if check_result.get('passed', False) else "✗" report_text.append(f"• {status_icon} {check_name}: {check_result.get('message', '')}") # Issues if any(issue_counts.values()): report_text.append("\nDETECTED ISSUES:") if issue_counts['Critical'] > 0: report_text.append("\nCRITICAL:") for issue in issues.get('critical', []): report_text.append(f" • {issue}") if issue_counts['Warnings'] > 0: report_text.append("\nWARNINGS:") for issue in issues.get('warning', []): report_text.append(f" • {issue}") if issue_counts['Informational'] > 0: report_text.append("\nINFORMATIONAL:") for issue in issues.get('info', []): report_text.append(f" • {issue}") # Recommendations recommendations = validation_results.get('recommendations', []) if recommendations: report_text.append("\nRECOMMENDATIONS:") for i, rec in enumerate(recommendations, 1): report_text.append(f"{i}. {rec}") ax5.text(0.02, 0.98, '\n'.join(report_text), transform=ax5.transAxes, fontsize=9, verticalalignment='top', fontfamily='monospace', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.1)) plt.tight_layout() # Save filepath = self._save_figure(fig, filename, "reports") self.plot_files['data_quality_report'] = filepath return filepath except Exception as e: logger.error(f"Error creating data quality report: {e}") return None # ============================================ # METHODS FOR BATCH SAVING # ============================================ def save_all_preprocessing_plots(self) -> Dict[str, str]: """ Save all preprocessing plots from current session Returns: -------- Dict[str, str] : dictionary with paths to saved plots """ logger.info("Saving all preprocessing plots...") plots_saved = {} # Get all open figures figure_numbers = plt.get_fignums() if not figure_numbers: logger.warning("No open plots to save") return plots_saved # Save each plot for fig_num in figure_numbers: fig = plt.figure(fig_num) filename = f"preprocessing_plot_{fig_num}.png" filepath = self._save_figure(fig, filename, "preprocessing") if filepath: plots_saved[f"plot_{fig_num}"] = filepath logger.info(f"Saved {len(plots_saved)} preprocessing plots") return plots_saved def create_all_visualizations( self, data: pd.DataFrame, processed_data: pd.DataFrame = None, feature_importance: Dict = None, decomposition_result: Dict = None, validation_results: Dict = None, preprocessing_stages: Dict = None ) -> Dict[str, str]: """ Create all visualisations in one call Parameters: ----------- data : pd.DataFrame Original data processed_data : pd.DataFrame, optional Processed data feature_importance : Dict, optional Feature importance decomposition_result : Dict, optional Decomposition results validation_results : Dict, optional Validation results preprocessing_stages : Dict, optional Preprocessing stages Returns: -------- Dict[str, str] : dictionary with paths to created plots """ logger.info("\n" + "="*80) logger.info("STARTING ALL VISUALISATIONS CREATION") logger.info("="*80) result_files = {} # 1. Summary dashboard if data is not None: logger.info("Creating summary dashboard...") summary_path = self.create_summary_dashboard(data, preprocessing_stages) if summary_path: result_files['summary'] = summary_path # 2. Correlation heatmaps if data is not None: logger.info("Creating correlation heatmaps...") main_corr, target_corr = self.create_correlation_heatmap(data) if main_corr: result_files['correlation_main'] = main_corr if target_corr: result_files['correlation_target'] = target_corr # 3. Distribution comparison if data is not None and processed_data is not None: logger.info("Creating distribution comparison...") dist_path = self.create_distribution_comparison(data, processed_data) if dist_path: result_files['distribution'] = dist_path # 4. Feature importance if feature_importance: logger.info("Creating feature importance plot...") feat_path = self.create_feature_importance_plot(feature_importance) if feat_path: result_files['feature_importance'] = feat_path # 5. Time series decomposition if decomposition_result: logger.info("Creating time series decomposition...") decomp_path = self.create_time_series_decomposition_plot(decomposition_result) if decomp_path: result_files['decomposition'] = decomp_path # 6. Data quality report if validation_results: logger.info("Creating data quality report...") quality_path = self.create_data_quality_report(validation_results) if quality_path: result_files['quality_report'] = quality_path # Save information about all plots self.save_plots_info() logger.info("\n" + "="*80) logger.info("VISUALISATIONS SUCCESSFULLY CREATED") logger.info("="*80) for plot_name, plot_path in result_files.items(): if plot_path: logger.info(f"✓ {plot_name}: {plot_path}") return result_files def get_all_plots(self) -> Dict: """Get information about all created plots""" return self.plot_files def save_plots_info(self, filename: str = "plots_info.json") -> None: """Save plot information to JSON file""" try: plots_info = { 'total_plots': len(self.plot_files), 'plots': self.plot_files, 'directories': { 'correlations': self.correlations_dir, 'distributions': self.distributions_dir, 'features': self.features_dir, 'time_series': self.time_series_dir, 'preprocessing': self.preprocessing_dir, 'summary': self.summary_dir, 'reports': self.reports_dir }, 'generation_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'config': { 'target_column': self.config.target_column, 'results_dir': self.config.results_dir } } filepath = os.path.join(self.reports_dir, filename) with open(filepath, 'w', encoding='utf-8') as f: json.dump(plots_info, f, indent=4, ensure_ascii=False, default=str) logger.info(f"✓ Plot information saved: {filepath}") except Exception as e: logger.error(f"✗ Error saving plot information: {e}") def move_existing_plots(self, source_dir: str = None) -> Dict[str, str]: """ Move existing plots from specified directory to structured folders Parameters: ----------- source_dir : str, optional Directory with existing plots Returns: -------- Dict[str, str] : dictionary with information about moved files """ if source_dir is None: source_dir = self.plots_dir if not os.path.exists(source_dir): logger.warning(f"Source directory doesn't exist: {source_dir}") return {} # File to folder mapping file_to_folder_map = { # Time series 'data_split.png': 'time_series', 'stationarity_raskhodvoda.png': 'time_series', 'stationarity_analysis.png': 'time_series', 'temporal_outliers.png': 'time_series', # Correlations 'feature_selection_correlation.png': 'correlations', # Preprocessing 'missing_values_analysis.png': 'preprocessing', 'outlier_handling_results.png': 'preprocessing', 'outliers_analysis.png': 'preprocessing', 'scaling_results.png': 'preprocessing', # Default 'default': 'summary' } moved_files = {} for filename in os.listdir(source_dir): if filename.endswith('.png'): source_path = os.path.join(source_dir, filename) # Determine destination folder target_folder = file_to_folder_map.get(filename, file_to_folder_map['default']) target_dir = os.path.join(self.plots_dir, target_folder) # Create destination folder if doesn't exist os.makedirs(target_dir, exist_ok=True) # Target path target_path = os.path.join(target_dir, filename) try: # Move file os.rename(source_path, target_path) moved_files[filename] = target_path logger.info(f"Moved: {filename} -> {target_folder}/") except Exception as e: logger.error(f"Error moving {filename}: {e}") logger.info(f"Moved {len(moved_files)} files") return moved_files