Spaces:
Sleeping
Sleeping
| # ============================================ | |
| # CLASS 13: VISUALISATION MANAGER (UPDATED) | |
| # ============================================ | |
| import os | |
| from datetime import datetime | |
| import json | |
| from typing import Dict, List, Optional, Tuple, Union, Any | |
| import pandas as pd | |
| import numpy as np | |
| from statsmodels.graphics.tsaplots import plot_acf, plot_pacf | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from scipy.stats import gaussian_kde | |
| import matplotlib | |
| matplotlib.use('Agg') # Use non-display backend | |
| from config.config import Config | |
| import logging | |
| # Logging setup | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| class VisualisationManager: | |
| """Class for managing all visualisations""" | |
def __init__(self, config: Config):
    """
    Set up a visualisation manager for one experiment.

    Parameters:
    -----------
    config : Config
        Experiment configuration; kept on the instance as ``self.config``.
    """
    # Bookkeeping starts empty: no plots generated, no files recorded,
    # and a figure counter of zero.
    self.figure_count = 0
    self.plot_files = {}
    self.plots_generated = {}
    self.config = config
    # Prepare the output folders once the configuration is attached.
    self._create_directory_structure()
def _create_directory_structure(self) -> None:
    """Define the output folder attributes and make sure each folder exists."""
    results_root = self.config.results_dir
    # (attribute name, path parts below the results directory), listed in
    # the order in which the folders are created.
    layout = (
        ("plots_dir", ("plots",)),
        ("correlations_dir", ("plots", "correlations")),
        ("distributions_dir", ("plots", "distributions")),
        ("features_dir", ("plots", "features")),
        ("time_series_dir", ("plots", "time_series")),
        ("preprocessing_dir", ("plots", "preprocessing")),
        ("summary_dir", ("plots", "summary")),
        ("reports_dir", ("reports",)),
    )
    # Assign every path attribute first, so they all exist on the instance
    # even if creating one of the folders fails afterwards.
    for attr_name, parts in layout:
        setattr(self, attr_name, os.path.join(results_root, *parts))
    for attr_name, _ in layout:
        folder = getattr(self, attr_name)
        os.makedirs(folder, exist_ok=True)
        logger.debug(f"Created directory: {folder}")
def _save_figure(self, fig: plt.Figure, filename: str,
                 subdirectory: Optional[str] = None, dpi: int = 300) -> Optional[str]:
    """
    Save plot and close it

    The figure is always closed before this method returns or raises, so a
    failure while preparing the target path does not leave it open.

    Parameters:
    -----------
    fig : matplotlib.figure.Figure
        Plot figure object
    filename : str
        Filename for saving; ".png" is appended when it does not already
        end with it
    subdirectory : str, optional
        Subdirectory of ``self.plots_dir`` for saving (created if missing);
        when omitted the file goes directly into ``self.plots_dir``
    dpi : int
        Resolution passed to ``fig.savefig``

    Returns:
    --------
    Optional[str] : full path to saved file, or None if ``savefig`` failed
    """
    try:
        if not filename.endswith('.png'):
            filename = f"{filename}.png"
        if subdirectory:
            save_dir = os.path.join(self.plots_dir, subdirectory)
            os.makedirs(save_dir, exist_ok=True)
        else:
            save_dir = self.plots_dir
        filepath = os.path.join(save_dir, filename)
        try:
            fig.savefig(filepath, dpi=dpi, bbox_inches='tight', facecolor='white')
            logger.info(f"✓ Plot saved: {filepath}")
        except Exception as e:
            # A failed write is reported through the None return value.
            logger.error(f"✗ Error saving plot {filename}: {e}")
            filepath = None
        return filepath
    finally:
        # Close plot without display. Doing it in ``finally`` also covers
        # errors raised before savefig (e.g. os.makedirs), which still
        # propagate to the caller as before.
        plt.close(fig)
| # ============================================ | |
| # MAIN VISUALISATION METHODS | |
| # ============================================ | |
def create_summary_dashboard(
    self,
    data: pd.DataFrame,
    preprocessing_stages: Optional[Dict] = None,
    filename: str = "summary_dashboard"
) -> Optional[str]:
    """
    Create summary visualisation dashboard

    Builds a single figure on a 6x4 grid with up to ten panels: target time
    series, target histogram, correlation matrix, monthly and day-of-week
    averages, rolling mean with volatility band, preprocessing statistics,
    a text summary of the data, and ACF / PACF plots. The figure is saved
    into the "summary" plot subdirectory and the path is stored under
    ``self.plot_files['summary_dashboard']``.

    Parameters:
    -----------
    data : pd.DataFrame
        Data for visualisation. Date-based panels are drawn only when the
        index is a ``pd.DatetimeIndex``.
    preprocessing_stages : Dict, optional
        Preprocessing stages information: stage name -> value shown as a
        bar. Values are formatted with ``:.2f`` (presumably numeric -
        confirm with callers).
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file or None if error
    """
    logger.info("\n" + "="*80)
    logger.info("CREATING SUMMARY DASHBOARD")
    logger.info("="*80)
    target_col = self.config.target_column
    fig = None
    try:
        # Create large dashboard
        fig = plt.figure(figsize=(20, 24))
        gs = fig.add_gridspec(6, 4, hspace=0.3, wspace=0.3)
        has_target = target_col in data.columns
        has_datetime_index = isinstance(data.index, pd.DatetimeIndex)

        # 1. Time series of target variable
        ax1 = fig.add_subplot(gs[0, :2])
        if has_target and has_datetime_index:
            ax1.plot(data.index, data[target_col], linewidth=1, color='blue', alpha=0.7)
            ax1.set_title(f'Time Series: {target_col}', fontsize=12, fontweight='bold')
            ax1.set_xlabel('Date', fontsize=10)
            ax1.set_ylabel(target_col, fontsize=10)
            ax1.grid(True, alpha=0.3)
            ax1.tick_params(axis='x', rotation=45)
        else:
            ax1.text(0.5, 0.5, 'No time series data available',
                     ha='center', va='center', transform=ax1.transAxes)

        # 2. Target variable distribution
        ax2 = fig.add_subplot(gs[0, 2:])
        if has_target:
            target_values = data[target_col].dropna()
            if len(target_values) > 0:
                ax2.hist(target_values, bins=30, edgecolor='black', alpha=0.7, color='green')
                ax2.set_title(f'Distribution: {target_col}', fontsize=12, fontweight='bold')
                ax2.set_xlabel(target_col, fontsize=10)
                ax2.set_ylabel('Frequency', fontsize=10)
                ax2.grid(True, alpha=0.3)
            else:
                ax2.text(0.5, 0.5, 'No data for distribution',
                         ha='center', va='center', transform=ax2.transAxes)

        # 3. Correlation matrix (top features)
        ax3 = fig.add_subplot(gs[1, :])
        numeric_cols = data.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) > 1:
            display_cols = list(numeric_cols[:15])
            # Make sure the target is shown, but only when it is numeric:
            # forcing a non-numeric column into .corr() would fail.
            if target_col in numeric_cols and target_col not in display_cols:
                display_cols = [target_col] + display_cols[:14]
            corr_matrix = data[display_cols].corr()
            # Hide the upper triangle (including the diagonal).
            mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
            im = ax3.imshow(corr_matrix.where(~mask), cmap='coolwarm',
                            vmin=-1, vmax=1, aspect='auto')
            ax3.set_title('Correlation Matrix (Top 15 Features)',
                          fontsize=12, fontweight='bold')
            ax3.set_xticks(range(len(display_cols)))
            ax3.set_yticks(range(len(display_cols)))
            ax3.set_xticklabels(display_cols, rotation=90, fontsize=8)
            ax3.set_yticklabels(display_cols, fontsize=8)
            fig.colorbar(im, ax=ax3, shrink=0.8)

        # 4. Seasonal patterns
        ax4 = fig.add_subplot(gs[2, :2])
        if has_target and has_datetime_index:
            # Group directly on the index instead of copying the whole frame.
            monthly_avg = data[target_col].groupby(data.index.month).mean()
            colors = plt.cm.Set3(np.linspace(0, 1, len(monthly_avg)))
            ax4.bar(monthly_avg.index, monthly_avg.values, color=colors, edgecolor='black')
            ax4.set_title('Average Values by Month', fontsize=12, fontweight='bold')
            ax4.set_xlabel('Month', fontsize=10)
            ax4.set_ylabel(f'Average {target_col}', fontsize=10)
            ax4.set_xticks(range(1, 13))
            month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                           'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
            ax4.set_xticklabels(month_names)
            ax4.grid(True, alpha=0.3, axis='y')

        # 5. Weekly patterns
        ax5 = fig.add_subplot(gs[2, 2:])
        if has_target and has_datetime_index:
            daily_avg = data[target_col].groupby(data.index.dayofweek).mean()
            colors = plt.cm.Paired(np.linspace(0, 1, len(daily_avg)))
            ax5.bar(daily_avg.index, daily_avg.values, color=colors, edgecolor='black')
            ax5.set_title('Average Values by Day of Week', fontsize=12, fontweight='bold')
            ax5.set_xlabel('Day of Week', fontsize=10)
            ax5.set_ylabel(f'Average {target_col}', fontsize=10)
            ax5.set_xticks(range(7))
            ax5.set_xticklabels(['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'])
            ax5.grid(True, alpha=0.3, axis='y')

        # 6. Trend and seasonality
        ax6 = fig.add_subplot(gs[3, :])
        if has_target and len(data) > 30:
            try:
                # Window is a tenth of the data, capped at 365 rows.
                window_size = min(365, len(data) // 10)
                if window_size >= 7:
                    roller = data[target_col].rolling(window=window_size, center=True)
                    rolling_mean = roller.mean()
                    rolling_std = roller.std()
                    ax6.plot(data.index, data[target_col], alpha=0.5,
                             label='Original Series', linewidth=0.5, color='blue')
                    ax6.plot(rolling_mean.index, rolling_mean,
                             label=f'Rolling Mean ({window_size} days)',
                             color='red', linewidth=2)
                    ax6.fill_between(rolling_mean.index,
                                     rolling_mean - rolling_std,
                                     rolling_mean + rolling_std,
                                     alpha=0.2, color='red')
                    ax6.set_title('Trend and Volatility', fontsize=12, fontweight='bold')
                    ax6.set_xlabel('Date', fontsize=10)
                    ax6.set_ylabel(target_col, fontsize=10)
                    ax6.legend(fontsize=9, loc='upper left')
                    ax6.grid(True, alpha=0.3)
                else:
                    ax6.text(0.5, 0.5, 'Insufficient data for trend analysis',
                             ha='center', va='center', transform=ax6.transAxes)
            except Exception as e:
                logger.warning(f"Error plotting trend: {e}")
                ax6.text(0.5, 0.5, 'Error plotting trend',
                         ha='center', va='center', transform=ax6.transAxes)

        # 7. Preprocessing statistics
        if preprocessing_stages:
            ax7 = fig.add_subplot(gs[4, :2])
            stages = list(preprocessing_stages.keys())
            stage_values = list(preprocessing_stages.values())
            colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(stages)))
            bars = ax7.bar(range(len(stages)), stage_values, color=colors, edgecolor='black')
            ax7.set_title('Preprocessing Statistics', fontsize=12, fontweight='bold')
            ax7.set_xlabel('Processing Stage', fontsize=10)
            ax7.set_ylabel('Value', fontsize=10)
            ax7.set_xticks(range(len(stages)))
            # Long stage names are truncated to 15 characters plus an ellipsis.
            ax7.set_xticklabels([s[:15] + '...' if len(s) > 15 else s for s in stages],
                                rotation=45, ha='right', fontsize=9)
            ax7.grid(True, alpha=0.3, axis='y')
            # Add values on bars
            for bar, value in zip(bars, stage_values):
                height = bar.get_height()
                ax7.text(bar.get_x() + bar.get_width()/2., height,
                         f'{value:.2f}', ha='center', va='bottom', fontsize=8)

        # 8. Data information
        ax8 = fig.add_subplot(gs[4, 2:])
        ax8.axis('off')
        info_text = [
            "GENERAL CHARACTERISTICS:",
            f"• Number of records: {len(data):,}",
            f"• Number of features: {len(data.columns)}",
        ]
        if has_datetime_index:
            info_text.append(f"• Period: {data.index.min().strftime('%Y-%m-%d')} - "
                             f"{data.index.max().strftime('%Y-%m-%d')}")
            info_text.append(f"• Days of data: {(data.index.max() - data.index.min()).days}")
        if has_target:
            target_stats = data[target_col].describe()
            info_text.append(f"\nTARGET VARIABLE '{target_col}':")
            info_text.append(f"• Mean: {target_stats['mean']:.2f}")
            info_text.append(f"• Standard deviation: {target_stats['std']:.2f}")
            info_text.append(f"• Minimum: {target_stats['min']:.2f}")
            info_text.append(f"• 25%: {target_stats['25%']:.2f}")
            info_text.append(f"• 50% (median): {target_stats['50%']:.2f}")
            info_text.append(f"• 75%: {target_stats['75%']:.2f}")
            info_text.append(f"• Maximum: {target_stats['max']:.2f}")
        info_text.append(f"\nDATA TYPES:")
        for dtype, count in data.dtypes.value_counts().items():
            info_text.append(f"• {dtype}: {count} columns")
        missing_info = data.isnull().sum()
        missing_total = missing_info.sum()
        # Avoid dividing by zero for a frame without any cells.
        missing_percent = missing_total / data.size * 100 if data.size else 0.0
        info_text.append(f"\nMISSING VALUES:")
        info_text.append(f"• Total missing: {missing_total:,}")
        info_text.append(f"• Missing percentage: {missing_percent:.2f}%")
        if missing_total > 0:
            # Only list columns that really have gaps; a plain nlargest(5)
            # would also list columns with zero missing values.
            top_missing = missing_info[missing_info > 0].nlargest(5)
            info_text.append(f"• Top 5 columns with missing values:")
            for col, count in top_missing.items():
                percent = count / len(data) * 100
                info_text.append(f" {col}: {count} ({percent:.1f}%)")
        ax8.text(0.02, 0.98, '\n'.join(info_text), transform=ax8.transAxes,
                 fontsize=8, verticalalignment='top', fontfamily='monospace')

        # 9. Autocorrelation plot
        ax9 = fig.add_subplot(gs[5, :2])
        if has_target:
            try:
                series = data[target_col].dropna()
                if len(series) > 50:
                    plot_acf(series, lags=min(50, len(series)-1), ax=ax9, alpha=0.05)
                    ax9.set_title('Autocorrelation Function (ACF)', fontsize=12, fontweight='bold')
                    ax9.set_xlabel('Lag', fontsize=10)
                    ax9.set_ylabel('Autocorrelation', fontsize=10)
                    ax9.grid(True, alpha=0.3)
                else:
                    ax9.text(0.5, 0.5, 'Insufficient data for ACF',
                             ha='center', va='center', transform=ax9.transAxes)
            except Exception as e:
                logger.warning(f"Error plotting ACF: {e}")
                ax9.text(0.5, 0.5, 'Error calculating ACF',
                         ha='center', va='center', transform=ax9.transAxes)

        # 10. Partial autocorrelation plot
        ax10 = fig.add_subplot(gs[5, 2:])
        if has_target:
            try:
                series = data[target_col].dropna()
                if len(series) > 50:
                    # statsmodels only computes PACF for lags below half the
                    # sample size, so cap the lag count accordingly.
                    pacf_lags = min(50, len(series) // 2 - 1)
                    plot_pacf(series, lags=pacf_lags, ax=ax10, alpha=0.05)
                    ax10.set_title('Partial Autocorrelation Function (PACF)',
                                   fontsize=12, fontweight='bold')
                    ax10.set_xlabel('Lag', fontsize=10)
                    ax10.set_ylabel('Partial Autocorrelation', fontsize=10)
                    ax10.grid(True, alpha=0.3)
                else:
                    ax10.text(0.5, 0.5, 'Insufficient data for PACF',
                              ha='center', va='center', transform=ax10.transAxes)
            except Exception as e:
                logger.warning(f"Error plotting PACF: {e}")
                ax10.text(0.5, 0.5, 'Error calculating PACF',
                          ha='center', va='center', transform=ax10.transAxes)

        fig.suptitle('Data Analysis Summary Dashboard', fontsize=16, fontweight='bold', y=0.98)
        fig.tight_layout()
        # Save (this also closes the figure)
        filepath = self._save_figure(fig, filename, "summary")
        self.plot_files['summary_dashboard'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error creating summary dashboard: {e}")
        # Do not leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
| # ============================================ | |
| # SPECIFIC METHODS FOR SAVING YOUR PLOTS | |
| # ============================================ | |
def save_data_split_plot(self, filename: str = "data_split.png") -> Optional[str]:
    """
    Save data split plot

    Saves the currently active pyplot figure into the "time_series" plot
    subdirectory and records the path under ``self.plot_files['data_split']``.

    Parameters:
    -----------
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None when no figure is open or
        saving failed
    """
    try:
        # plt.gcf() creates a new empty figure when none exists, which would
        # be saved as a blank image; refuse instead.
        if not plt.get_fignums():
            logger.warning("No active figure to save for data_split plot")
            return None
        fig = plt.gcf()  # Get current figure
        filepath = self._save_figure(fig, filename, "time_series")
        self.plot_files['data_split'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error saving data_split plot: {e}")
        return None
def save_feature_selection_correlation_plot(self, filename: str = "feature_selection_correlation.png") -> Optional[str]:
    """
    Save feature selection correlation plot

    Saves the currently active pyplot figure into the "correlations" plot
    subdirectory and records the path under
    ``self.plot_files['feature_selection_correlation']``.

    Parameters:
    -----------
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None when no figure is open or
        saving failed
    """
    try:
        # plt.gcf() creates a new empty figure when none exists, which would
        # be saved as a blank image; refuse instead.
        if not plt.get_fignums():
            logger.warning("No active figure to save for feature_selection_correlation plot")
            return None
        fig = plt.gcf()  # Get current figure
        filepath = self._save_figure(fig, filename, "correlations")
        self.plot_files['feature_selection_correlation'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error saving feature_selection_correlation plot: {e}")
        return None
def save_missing_values_analysis_plot(self, filename: str = "missing_values_analysis.png") -> Optional[str]:
    """
    Save missing values analysis plot

    Saves the currently active pyplot figure into the "preprocessing" plot
    subdirectory and records the path under
    ``self.plot_files['missing_values_analysis']``.

    Parameters:
    -----------
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None when no figure is open or
        saving failed
    """
    try:
        # plt.gcf() creates a new empty figure when none exists, which would
        # be saved as a blank image; refuse instead.
        if not plt.get_fignums():
            logger.warning("No active figure to save for missing_values_analysis plot")
            return None
        fig = plt.gcf()  # Get current figure
        filepath = self._save_figure(fig, filename, "preprocessing")
        self.plot_files['missing_values_analysis'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error saving missing_values_analysis plot: {e}")
        return None
def save_outlier_handling_results_plot(self, filename: str = "outlier_handling_results.png") -> Optional[str]:
    """
    Save outlier handling results plot

    Saves the currently active pyplot figure into the "preprocessing" plot
    subdirectory and records the path under
    ``self.plot_files['outlier_handling_results']``.

    Parameters:
    -----------
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None when no figure is open or
        saving failed
    """
    try:
        # plt.gcf() creates a new empty figure when none exists, which would
        # be saved as a blank image; refuse instead.
        if not plt.get_fignums():
            logger.warning("No active figure to save for outlier_handling_results plot")
            return None
        fig = plt.gcf()  # Get current figure
        filepath = self._save_figure(fig, filename, "preprocessing")
        self.plot_files['outlier_handling_results'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error saving outlier_handling_results plot: {e}")
        return None
def save_outliers_analysis_plot(self, filename: str = "outliers_analysis.png") -> Optional[str]:
    """
    Save outliers analysis plot

    Saves the currently active pyplot figure into the "preprocessing" plot
    subdirectory and records the path under
    ``self.plot_files['outliers_analysis']``.

    Parameters:
    -----------
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None when no figure is open or
        saving failed
    """
    try:
        # plt.gcf() creates a new empty figure when none exists, which would
        # be saved as a blank image; refuse instead.
        if not plt.get_fignums():
            logger.warning("No active figure to save for outliers_analysis plot")
            return None
        fig = plt.gcf()  # Get current figure
        filepath = self._save_figure(fig, filename, "preprocessing")
        self.plot_files['outliers_analysis'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error saving outliers_analysis plot: {e}")
        return None
def save_scaling_results_plot(self, filename: str = "scaling_results.png") -> Optional[str]:
    """
    Save scaling results plot

    Saves the currently active pyplot figure into the "preprocessing" plot
    subdirectory and records the path under
    ``self.plot_files['scaling_results']``.

    Parameters:
    -----------
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None when no figure is open or
        saving failed
    """
    try:
        # plt.gcf() creates a new empty figure when none exists, which would
        # be saved as a blank image; refuse instead.
        if not plt.get_fignums():
            logger.warning("No active figure to save for scaling_results plot")
            return None
        fig = plt.gcf()  # Get current figure
        filepath = self._save_figure(fig, filename, "preprocessing")
        self.plot_files['scaling_results'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error saving scaling_results plot: {e}")
        return None
def save_stationarity_analysis_plot(self, filename: str = "stationarity_analysis.png") -> Optional[str]:
    """
    Save stationarity analysis plot

    Saves the currently active pyplot figure into the "time_series" plot
    subdirectory and records the path under
    ``self.plot_files['stationarity_analysis']``.

    Parameters:
    -----------
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None when no figure is open or
        saving failed
    """
    try:
        # plt.gcf() creates a new empty figure when none exists, which would
        # be saved as a blank image; refuse instead.
        if not plt.get_fignums():
            logger.warning("No active figure to save for stationarity_analysis plot")
            return None
        fig = plt.gcf()  # Get current figure
        filepath = self._save_figure(fig, filename, "time_series")
        self.plot_files['stationarity_analysis'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error saving stationarity_analysis plot: {e}")
        return None
def save_temporal_outliers_plot(self, filename: str = "temporal_outliers.png") -> Optional[str]:
    """
    Save temporal outliers plot

    Saves the currently active pyplot figure into the "time_series" plot
    subdirectory and records the path under
    ``self.plot_files['temporal_outliers']``.

    Parameters:
    -----------
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None when no figure is open or
        saving failed
    """
    try:
        # plt.gcf() creates a new empty figure when none exists, which would
        # be saved as a blank image; refuse instead.
        if not plt.get_fignums():
            logger.warning("No active figure to save for temporal_outliers plot")
            return None
        fig = plt.gcf()  # Get current figure
        filepath = self._save_figure(fig, filename, "time_series")
        self.plot_files['temporal_outliers'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error saving temporal_outliers plot: {e}")
        return None
| # ============================================ | |
| # UNIVERSAL METHOD FOR SAVING ANY PLOT | |
| # ============================================ | |
def save_current_plot(self, filename: str, subdirectory: Optional[str] = None) -> Optional[str]:
    """
    Universal method for saving current plot

    Saves the currently active pyplot figure and records the path in
    ``self.plot_files`` under the filename with any '.png' / '.jpg' text
    removed.

    Parameters:
    -----------
    filename : str
        Filename for saving
    subdirectory : str, optional
        Subdirectory for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None when no figure is open or
        saving failed
    """
    try:
        # plt.gcf() creates a new empty figure when none exists, which would
        # be saved as a blank image; refuse instead.
        if not plt.get_fignums():
            logger.warning(f"No active figure to save as {filename}")
            return None
        fig = plt.gcf()  # Get current figure
        filepath = self._save_figure(fig, filename, subdirectory)
        # Save plot information
        plot_key = filename.replace('.png', '').replace('.jpg', '')
        self.plot_files[plot_key] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error saving plot {filename}: {e}")
        return None
| # ============================================ | |
| # ADDITIONAL VISUALISATION METHODS | |
| # ============================================ | |
def create_feature_importance_plot(
    self,
    feature_importance: Dict,
    top_n: int = 20,
    filename: str = "feature_importance"
) -> Optional[str]:
    """
    Create feature importance plot

    Draws a horizontal bar chart of the ``top_n`` largest importances,
    saves it into the "features" plot subdirectory and records the path
    under ``self.plot_files['feature_importance']``.

    Parameters:
    -----------
    feature_importance : Dict
        Dictionary with feature importance (feature name -> importance)
    top_n : int
        Number of top features to display
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file or None if error
    """
    if not feature_importance:
        logger.warning("No feature importance data for visualisation")
        return None
    fig = None
    try:
        # Convert to Series and sort
        importance_series = pd.Series(feature_importance).sort_values(ascending=False)
        top_features = importance_series.head(top_n)
        # Create plot
        fig, ax = plt.subplots(figsize=(12, 8))
        y_pos = np.arange(len(top_features))
        colors = plt.cm.plasma(np.linspace(0.2, 0.9, len(top_features)))
        bars = ax.barh(y_pos, top_features.values, color=colors, edgecolor='black')
        ax.set_yticks(y_pos)
        ax.set_yticklabels(top_features.index, fontsize=10)
        ax.invert_yaxis()  # largest importance at the top
        ax.set_xlabel('Feature Importance', fontsize=11, fontweight='bold')
        # Report the number of bars actually drawn, which can be smaller
        # than top_n when fewer features are available.
        ax.set_title(f'Top-{len(top_features)} Most Important Features',
                     fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='x')
        # Add values on bars
        for bar, value in zip(bars, top_features.values):
            ax.text(bar.get_width() * 1.01, bar.get_y() + bar.get_height()/2,
                    f'{value:.4f}', va='center', fontsize=9, fontweight='bold')
        # Add additional information (figure coordinates, top-left corner)
        fig.text(0.02, 0.98, f'Total features: {len(importance_series)}',
                 fontsize=9, verticalalignment='top')
        fig.tight_layout()
        # Save (this also closes the figure)
        filepath = self._save_figure(fig, filename, "features")
        self.plot_files['feature_importance'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error creating feature importance plot: {e}")
        # Do not leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
def create_correlation_heatmap(
    self,
    data: pd.DataFrame,
    top_n: int = 20,
    filename: str = "correlation_heatmap"
) -> Tuple[str, Optional[str]]:
    """
    Produce the correlation plots for a dataset.

    Two plots are attempted: the feature-to-feature heatmap, and - when the
    configured target column is among the numeric columns - a chart of
    correlations with the target.

    Parameters:
    -----------
    data : pd.DataFrame
        Data for analysis
    top_n : int
        Number of top features to display
    filename : str
        Filename for saving

    Returns:
    --------
    Tuple[str, Optional[str]]:
        (path to main heatmap, path to target correlation heatmap);
        (None, None) when fewer than two numeric columns exist or an
        error occurs
    """
    target_name = self.config.target_column
    try:
        numeric_names = data.select_dtypes(include=[np.number]).columns.tolist()
        if len(numeric_names) >= 2:
            # Feature-to-feature heatmap comes first.
            main_path = self._create_main_correlation_heatmap(
                data, numeric_names, top_n, filename
            )
            # The target chart is only built for a numeric target column.
            target_path = (
                self._create_target_correlation_heatmap(
                    data, target_name, numeric_names, filename
                )
                if target_name in data.columns and target_name in numeric_names
                else None
            )
            return main_path, target_path
        logger.warning("Insufficient numeric features for correlation analysis")
        return None, None
    except Exception as e:
        logger.error(f"Error creating correlation heatmap: {e}")
        return None, None
def _create_main_correlation_heatmap(
    self,
    data: pd.DataFrame,
    numeric_cols: List[str],
    top_n: int,
    filename: str
) -> Optional[str]:
    """
    Create main correlation heatmap

    Parameters:
    -----------
    data : pd.DataFrame
        Data for analysis
    numeric_cols : List[str]
        Names of the numeric columns to consider
    top_n : int
        Maximum number of columns shown; when more are available the ones
        with the highest variance are kept
    filename : str
        Filename for saving (inside the "correlations" subdirectory)

    Returns:
    --------
    Optional[str] : path to saved file, or None if saving failed

    Errors raised while computing or drawing the heatmap propagate to the
    caller; the figure is closed first.
    """
    # Limit number of features for better readability
    if len(numeric_cols) > top_n:
        # Select features with highest variance
        variances = data[numeric_cols].var().sort_values(ascending=False)
        selected_cols = variances.head(top_n).index.tolist()
    else:
        selected_cols = numeric_cols
    # Calculate correlation
    corr_matrix = data[selected_cols].corr()
    fig, ax = plt.subplots(figsize=(14, 12))
    try:
        # Mask for upper triangle (including the diagonal)
        mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
        # Create heatmap
        sns.heatmap(
            corr_matrix,
            annot=True,
            fmt='.2f',
            cmap='coolwarm',
            center=0,
            square=True,
            mask=mask,
            cbar_kws={'shrink': 0.8, 'label': 'Correlation Coefficient'},
            linewidths=0.5,
            linecolor='white',
            ax=ax,
            annot_kws={'size': 8}
        )
        # Report the number of columns actually shown, which can be
        # smaller than top_n.
        ax.set_title(f'Correlation Matrix Between Features (Top-{len(selected_cols)})',
                     fontsize=14, fontweight='bold', pad=20)
        fig.tight_layout()
    except Exception:
        # The caller swallows the exception, so close the figure here to
        # avoid leaving it registered with pyplot.
        plt.close(fig)
        raise
    # Save (this also closes the figure)
    filepath = self._save_figure(fig, filename, "correlations")
    self.plot_files['correlation_heatmap_main'] = filepath
    return filepath
def _create_target_correlation_heatmap(
    self,
    data: pd.DataFrame,
    target_col: str,
    numeric_cols: List[str],
    filename: str,
    top_n: int = 15
) -> Optional[str]:
    """
    Create target correlation heatmap

    Draws a horizontal bar chart of the features whose correlation with
    the target has the largest absolute value (green for non-negative,
    red for negative).

    Parameters:
    -----------
    data : pd.DataFrame
        Data for analysis
    target_col : str
        Name of the target column
    numeric_cols : List[str]
        Names of the numeric columns to correlate with the target
    filename : str
        Base filename; "_with_target" is appended before saving into the
        "correlations" subdirectory
    top_n : int
        Number of features to display (default 15)

    Returns:
    --------
    Optional[str] : path to saved file, or None when there is no defined
        correlation to show or saving failed

    Errors raised while drawing propagate to the caller; the figure is
    closed first.
    """
    # Calculate correlations with target variable. Undefined (NaN)
    # correlations, e.g. for constant columns, are dropped so they do not
    # show up as empty bars labelled "nan".
    correlations = (
        data[numeric_cols]
        .corrwith(data[target_col])
        .dropna()
        .sort_values(key=abs, ascending=False)
    )
    # Exclude target variable itself
    correlations = correlations[correlations.index != target_col]
    # Take top features
    top_features = correlations.head(top_n)
    if top_features.empty:
        logger.warning(f"No valid feature correlations with target '{target_col}'")
        return None
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        colors = ['red' if x < 0 else 'green' for x in top_features.values]
        bars = ax.barh(range(len(top_features)), top_features.values,
                       color=colors, edgecolor='black')
        ax.set_yticks(range(len(top_features)))
        ax.set_yticklabels(top_features.index, fontsize=10)
        ax.invert_yaxis()  # strongest correlation at the top
        ax.set_xlabel('Correlation Coefficient', fontsize=11, fontweight='bold')
        ax.set_title(f'Feature Correlations with Target Variable "{target_col}"',
                     fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, axis='x')
        ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
        # Add values on bars: to the right of positive bars, to the left
        # of negative ones
        for bar, value in zip(bars, top_features.values):
            width = bar.get_width()
            ax.text(width + (0.01 if width >= 0 else -0.04),
                    bar.get_y() + bar.get_height()/2,
                    f'{value:.3f}',
                    va='center',
                    ha='left' if width >= 0 else 'right',
                    fontsize=9,
                    fontweight='bold',
                    color='black')
        fig.tight_layout()
    except Exception:
        # The caller swallows the exception, so close the figure here to
        # avoid leaving it registered with pyplot.
        plt.close(fig)
        raise
    # Save (this also closes the figure)
    target_filename = f"{filename}_with_target"
    filepath = self._save_figure(fig, target_filename, "correlations")
    self.plot_files['correlation_with_target'] = filepath
    return filepath
def create_distribution_comparison(
    self,
    original_data: pd.DataFrame,
    processed_data: pd.DataFrame,
    columns: Optional[List[str]] = None,
    max_columns: int = 12,
    filename: str = "distribution_comparison"
) -> Optional[str]:
    """
    Compare distributions before and after processing

    One subplot per column (at most four per row) overlays density
    histograms on shared bins, adds KDE curves when a side has more than
    10 values, and shows mean / standard deviation for both sides. The
    figure is saved into the "distributions" plot subdirectory and its
    path recorded under ``self.plot_files['distribution_comparison']``.

    Parameters:
    -----------
    original_data : pd.DataFrame
        Original data
    processed_data : pd.DataFrame
        Processed data
    columns : List[str], optional
        List of columns to compare. When omitted, the numeric columns
        common to both frames are used, ordered by variance in the
        original data.
    max_columns : int
        Maximum number of columns selected automatically; an explicit
        ``columns`` list is not truncated
    filename : str
        Filename for saving

    Returns:
    --------
    Optional[str] : path to saved file, or None if there is nothing to
        plot or an error occurs
    """
    fig = None
    try:
        if columns is None:
            # Select numeric columns common to both datasets, keeping the
            # original column order so the selection is deterministic.
            processed_numeric = set(
                processed_data.select_dtypes(include=[np.number]).columns
            )
            common_cols = [
                c for c in original_data.select_dtypes(include=[np.number]).columns
                if c in processed_numeric
            ]
            # Sort by variance in original data
            variances = original_data[common_cols].var().sort_values(ascending=False)
            columns = variances.head(max_columns).index.tolist()
        if len(columns) == 0:
            # Nothing to draw; the grid-size arithmetic below would
            # otherwise divide by zero.
            logger.warning("No columns available for distribution comparison")
            return None
        n_cols = min(4, len(columns))
        n_rows = (len(columns) + n_cols - 1) // n_cols  # ceiling division
        # squeeze=False always yields a 2-D array of axes, even for 1x1.
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(n_cols * 4, n_rows * 3.5),
                                 squeeze=False)
        fig.suptitle('Distribution Comparison Before and After Processing',
                     fontsize=16, fontweight='bold', y=0.98)
        axes = axes.ravel()
        for ax, col in zip(axes, columns):
            if col not in original_data.columns or col not in processed_data.columns:
                ax.text(0.5, 0.5, 'Column not found',
                        ha='center', va='center', transform=ax.transAxes)
                continue
            original_values = original_data[col].dropna()
            processed_values = processed_data[col].dropna()
            if len(original_values) == 0 or len(processed_values) == 0:
                ax.text(0.5, 0.5, 'No data',
                        ha='center', va='center', transform=ax.transAxes)
                continue
            # Use common bins for comparison
            all_values = pd.concat([original_values, processed_values])
            bins = np.histogram_bin_edges(all_values, bins=30)
            # Histograms
            ax.hist(original_values, bins=bins, alpha=0.5,
                    label='Before Processing', density=True, color='blue')
            ax.hist(processed_values, bins=bins, alpha=0.5,
                    label='After Processing', density=True, color='orange')
            # Add KDE. Each curve is attempted on its own so a failure on
            # one side (e.g. constant values) does not suppress the other.
            for values, curve_color in ((original_values, 'blue'),
                                        (processed_values, 'orange')):
                if len(values) <= 10:
                    continue
                try:
                    kde = gaussian_kde(values)
                    x_range = np.linspace(values.min(), values.max(), 100)
                    ax.plot(x_range, kde(x_range), color=curve_color,
                            linewidth=1.5, alpha=0.8)
                except Exception as e:
                    # The KDE overlay is best-effort; keep the histogram.
                    logger.debug(f"KDE overlay skipped for {col}: {e}")
            # Add statistics
            stats_text = [
                f"Before: μ={original_values.mean():.2f}, σ={original_values.std():.2f}",
                f"After: μ={processed_values.mean():.2f}, σ={processed_values.std():.2f}",
            ]
            ax.text(0.02, 0.98, '\n'.join(stats_text),
                    transform=ax.transAxes, fontsize=8,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            ax.set_title(f'{col}', fontsize=11, fontweight='bold')
            ax.set_xlabel('Value', fontsize=9)
            ax.set_ylabel('Density', fontsize=9)
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)
        # Hide unused subplots
        for ax in axes[len(columns):]:
            ax.set_visible(False)
        fig.tight_layout()
        # Save (this also closes the figure)
        filepath = self._save_figure(fig, filename, "distributions")
        self.plot_files['distribution_comparison'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error creating distribution comparison: {e}")
        # Do not leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
def create_time_series_decomposition_plot(
    self,
    decomposition_result: Dict,
    filename: str = "time_series_decomposition"
) -> Optional[str]:
    """
    Visualise time series decomposition

    Draws four stacked panels (observed, trend, seasonal, residual). A
    component that is absent from ``decomposition_result`` or is ``None``
    gets a 'No data' placeholder instead of an empty, untitled axis.

    Parameters:
    -----------
    decomposition_result : Dict
        Decomposition results. The keys 'observed', 'trend', 'seasonal'
        and 'residual' are read; each one is optional.
    filename : str
        Filename for saving

    Returns:
    --------
    str : path to saved file or None if error
    """
    target_col = self.config.target_column

    # One entry per panel, top to bottom:
    # (result key, y-axis label, panel title, line style)
    panels = [
        ('observed', 'Observed', 'Original Time Series',
         {'color': 'blue', 'linewidth': 1.5}),
        ('trend', 'Trend', 'Trend Component',
         {'color': 'red', 'linewidth': 2}),
        ('seasonal', 'Seasonal', 'Seasonal Component',
         {'color': 'green', 'linewidth': 1.5}),
        ('residual', 'Residuals', 'Residual Component',
         {'color': 'purple', 'linewidth': 1, 'alpha': 0.7}),
    ]

    fig = None
    try:
        fig, axes = plt.subplots(4, 1, figsize=(14, 10))
        fig.suptitle(f'Time Series Decomposition: {target_col}',
                     fontsize=16, fontweight='bold', y=0.98)

        for ax, (key, ylabel, title, line_style) in zip(axes, panels):
            component = decomposition_result.get(key)
            ax.set_title(title, fontsize=12)

            if component is None:
                ax.text(0.5, 0.5, 'No data',
                        ha='center', va='center', transform=ax.transAxes)
                continue

            ax.plot(component, **line_style)
            ax.set_ylabel(ylabel, fontsize=11, fontweight='bold')
            ax.grid(True, alpha=0.3)

            if key == 'residual':
                # Only the bottom panel carries the shared x-axis label
                ax.set_xlabel('Date', fontsize=11, fontweight='bold')

                # Add residual statistics
                if len(component) > 0:
                    stats_text = (f"Mean: {component.mean():.4f}\n"
                                  f"Std: {component.std():.4f}\n"
                                  f"Min: {component.min():.4f}\n"
                                  f"Max: {component.max():.4f}")
                    ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                            fontsize=8, verticalalignment='top',
                            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        plt.tight_layout()

        # Save
        filepath = self._save_figure(fig, filename, "time_series")
        self.plot_files['time_series_decomposition'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error creating time series decomposition: {e}")
        # Do not leave a half-drawn figure registered with pyplot: it would
        # leak memory and be picked up by anything that walks plt.get_fignums().
        if fig is not None:
            plt.close(fig)
        return None
def create_data_quality_report(
    self,
    validation_results: Dict,
    filename: str = "data_quality_report"
) -> Optional[str]:
    """
    Create visual data quality report

    The figure is laid out on a 3x3 grid: a radar chart of quality
    metrics, a bar chart of passed/failed basic checks, a ring chart of
    the overall score, a bar chart of issue counts by severity, and a
    text panel with the detailed report.

    Parameters:
    -----------
    validation_results : Dict
        Validation results. The keys read here are 'quality_metrics',
        'basic_checks', 'overall_score', 'status', 'issues' and
        'recommendations'; each one is optional.
    filename : str
        Filename for saving

    Returns:
    --------
    str : path to saved file or None if error
    """
    fig = None
    try:
        fig = plt.figure(figsize=(16, 12))
        fig.suptitle('Data Quality Report', fontsize=18, fontweight='bold', y=0.98)

        # Use GridSpec for more complex layout
        gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

        # 1. Quality radar chart (top left)
        ax1 = fig.add_subplot(gs[0, 0], projection='polar')
        categories = ['Size', 'Missing', 'Duplicates', 'Stability', 'Informativeness']
        metric_keys = ['size_score', 'missing_score', 'duplicates_score',
                       'stability_score', 'informativeness_score']

        if 'quality_metrics' in validation_results:
            quality_metrics = validation_results['quality_metrics']
            values = [quality_metrics.get(key, 0.5) for key in metric_keys]
        else:
            # NOTE(review): hard-coded placeholder scores, not measured
            # values; the chart gives no hint that they are synthetic.
            values = [0.8, 0.7, 0.9, 0.6, 0.8]

        n_categories = len(categories)
        angles = [n / float(n_categories) * 2 * np.pi for n in range(n_categories)]
        # Repeat the first point so the polygon closes on itself
        closed_angles = angles + angles[:1]
        closed_values = values + values[:1]

        ax1.plot(closed_angles, closed_values, 'o-', linewidth=2, color='blue')
        ax1.fill(closed_angles, closed_values, alpha=0.25, color='blue')
        ax1.set_xticks(angles)
        ax1.set_xticklabels(categories, fontsize=10)
        ax1.set_ylim(0, 1)
        ax1.set_title('Data Quality Radar Chart', fontsize=12, fontweight='bold')
        ax1.grid(True)

        # 2. Check status (top centre)
        ax2 = fig.add_subplot(gs[0, 1])
        basic_checks = validation_results.get('basic_checks', {})
        checks_total = len(basic_checks)
        checks_passed = sum(1 for check in basic_checks.values()
                            if check.get('passed', False))
        checks_failed = checks_total - checks_passed

        if checks_total > 0:
            check_counts = [checks_passed, checks_failed]
            # A zero-height bar shows no face colour, so fixed colours suffice
            bars = ax2.bar(['Passed', 'Failed'], check_counts,
                           color=['#4CAF50', '#FF6B6B'], edgecolor='black')
            ax2.set_title(f'Basic Checks: {checks_passed}/{checks_total}',
                          fontsize=12, fontweight='bold')
            ax2.set_ylabel('Number of Checks', fontsize=10)
            ax2.grid(True, alpha=0.3, axis='y')

            # Add values on bars
            for bar, value in zip(bars, check_counts):
                ax2.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                         f'{value}', ha='center', va='bottom',
                         fontsize=10, fontweight='bold')
        else:
            ax2.text(0.5, 0.5, 'No check data available',
                     ha='center', va='center', transform=ax2.transAxes)
            ax2.set_title('Basic Checks', fontsize=12, fontweight='bold')

        # 3. Overall score (top right)
        ax3 = fig.add_subplot(gs[0, 2])
        overall_score = validation_results.get('overall_score', 0)
        if overall_score is None:
            overall_score = 0
        status = validation_results.get('status', 'UNKNOWN')

        # pie() rejects negative wedge sizes, so clamp the plotted share to
        # 0..100; the label below still shows the score exactly as reported.
        score_share = min(max(float(overall_score), 0.0), 100.0)

        if score_share >= 80:
            colors = ['#4CAF50', '#E0E0E0']  # Green
        elif score_share >= 60:
            colors = ['#FFC107', '#E0E0E0']  # Yellow
        else:
            colors = ['#F44336', '#E0E0E0']  # Red

        # Draw a ring rather than a solid disc: the centre label uses the
        # status colour, which can equal the wedge colour and would be
        # unreadable on top of a filled pie.
        ax3.pie([score_share, 100 - score_share], colors=colors, startangle=90,
                autopct='%1.1f%%', pctdistance=0.85,
                wedgeprops=dict(width=0.3))

        # Central text
        status_colors = {'PASS': '#4CAF50', 'WARNING': '#FFC107', 'FAIL': '#F44336'}
        status_color = status_colors.get(status, '#757575')
        ax3.text(0, 0, f'{overall_score}/100\n{status}',
                 ha='center', va='center', fontsize=14, fontweight='bold',
                 color=status_color)
        ax3.set_title('Overall Quality Score', fontsize=12, fontweight='bold')

        # 4. Issue distribution by type (left middle)
        ax4 = fig.add_subplot(gs[1, 0])
        issues = validation_results.get('issues', {})
        # "or []" tolerates a severity that is present but set to None
        critical_issues = issues.get('critical', []) or []
        warning_issues = issues.get('warning', []) or []
        info_issues = issues.get('info', []) or []
        issue_counts = {
            'Critical': len(critical_issues),
            'Warnings': len(warning_issues),
            'Informational': len(info_issues)
        }
        has_issues = any(issue_counts.values())

        if has_issues:
            issue_labels = list(issue_counts.keys())
            issue_values = list(issue_counts.values())
            bars = ax4.bar(issue_labels, issue_values,
                           color=['#F44336', '#FF9800', '#2196F3'],
                           edgecolor='black')
            ax4.set_title('Data Issues by Type', fontsize=12, fontweight='bold')
            ax4.set_ylabel('Number of Issues', fontsize=10)
            ax4.tick_params(axis='x', rotation=45)
            ax4.grid(True, alpha=0.3, axis='y')

            # Add values on bars
            for bar, value in zip(bars, issue_values):
                ax4.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                         f'{value}', ha='center', va='bottom',
                         fontsize=10, fontweight='bold')
        else:
            ax4.text(0.5, 0.5, 'No issues detected',
                     ha='center', va='center', transform=ax4.transAxes, fontsize=12)
            ax4.set_title('Data Issues', fontsize=12, fontweight='bold')

        # 5. Detailed information (remaining cells)
        ax5 = fig.add_subplot(gs[1:, 1:])
        ax5.axis('off')

        # Form text report
        report_text = [
            "DETAILED REPORT:",
            "=" * 40,
            # Basic information
            "\nBASIC INFORMATION:",
            f"• Overall score: {overall_score}/100",
            f"• Status: {status}",
            f"• Checks passed: {checks_passed}/{checks_total}",
        ]

        # Check details
        if basic_checks:
            report_text.append("\nCHECK DETAILS:")
            for check_name, check_result in basic_checks.items():
                status_icon = "✓" if check_result.get('passed', False) else "✗"
                report_text.append(
                    f"• {status_icon} {check_name}: {check_result.get('message', '')}")

        # Issues
        if has_issues:
            report_text.append("\nDETECTED ISSUES:")
            issue_sections = [
                ("\nCRITICAL:", critical_issues),
                ("\nWARNINGS:", warning_issues),
                ("\nINFORMATIONAL:", info_issues),
            ]
            for heading, section_issues in issue_sections:
                if section_issues:
                    report_text.append(heading)
                    report_text.extend(f" • {issue}" for issue in section_issues)

        # Recommendations
        recommendations = validation_results.get('recommendations', [])
        if recommendations:
            report_text.append("\nRECOMMENDATIONS:")
            for i, rec in enumerate(recommendations, 1):
                report_text.append(f"{i}. {rec}")

        ax5.text(0.02, 0.98, '\n'.join(report_text), transform=ax5.transAxes,
                 fontsize=9, verticalalignment='top', fontfamily='monospace',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.1))

        plt.tight_layout()

        # Save
        filepath = self._save_figure(fig, filename, "reports")
        self.plot_files['data_quality_report'] = filepath
        return filepath
    except Exception as e:
        logger.error(f"Error creating data quality report: {e}")
        # Release the partially built figure so it neither leaks nor gets
        # picked up by code that iterates over the open pyplot figures.
        if fig is not None:
            plt.close(fig)
        return None
| # ============================================ | |
| # METHODS FOR BATCH SAVING | |
| # ============================================ | |
def save_all_preprocessing_plots(self) -> Dict[str, str]:
    """
    Save all preprocessing plots from current session

    Every figure currently registered with pyplot is handed to
    ``self._save_figure`` under the "preprocessing" category.

    Returns:
    --------
    Dict[str, str] : dictionary with paths to saved plots, keyed
        ``plot_<figure number>``; empty when no figure is open
    """
    logger.info("Saving all preprocessing plots...")
    plots_saved: Dict[str, str] = {}

    # Get all open figures
    figure_numbers = plt.get_fignums()
    if not figure_numbers:
        logger.warning("No open plots to save")
        return plots_saved

    # Save each plot
    for fig_num in figure_numbers:
        fig = plt.figure(fig_num)
        # NOTE(review): every other _save_figure call in this class passes a
        # name without an extension, so none is added here either;
        # presumably _save_figure appends it - confirm against _save_figure.
        filename = f"preprocessing_plot_{fig_num}"
        filepath = self._save_figure(fig, filename, "preprocessing")
        if filepath:
            plots_saved[f"plot_{fig_num}"] = filepath

    logger.info(f"Saved {len(plots_saved)} preprocessing plots")
    return plots_saved
def create_all_visualizations(
    self,
    data: pd.DataFrame,
    processed_data: Optional[pd.DataFrame] = None,
    feature_importance: Optional[Dict] = None,
    decomposition_result: Optional[Dict] = None,
    validation_results: Optional[Dict] = None,
    preprocessing_stages: Optional[Dict] = None
) -> Dict[str, str]:
    """
    Create all visualisations in one call

    Each plot is produced only when its inputs were supplied, and only
    plots whose creator returned a path appear in the result. Plot
    metadata is written via ``save_plots_info`` at the end.

    Parameters:
    -----------
    data : pd.DataFrame
        Original data
    processed_data : pd.DataFrame, optional
        Processed data
    feature_importance : Dict, optional
        Feature importance
    decomposition_result : Dict, optional
        Decomposition results
    validation_results : Dict, optional
        Validation results
    preprocessing_stages : Dict, optional
        Preprocessing stages

    Returns:
    --------
    Dict[str, str] : dictionary with paths to created plots
    """
    logger.info("\n" + "="*80)
    logger.info("STARTING ALL VISUALISATIONS CREATION")
    logger.info("="*80)

    result_files: Dict[str, str] = {}

    if data is not None:
        # 1. Summary dashboard
        logger.info("Creating summary dashboard...")
        summary_path = self.create_summary_dashboard(data, preprocessing_stages)
        if summary_path:
            result_files['summary'] = summary_path

        # 2. Correlation heatmaps
        logger.info("Creating correlation heatmaps...")
        # The other plot creators in this class return a bare None on
        # failure; tolerate the same here so that one failed heatmap does
        # not abort the remaining visualisations with an unpacking error.
        corr_paths = self.create_correlation_heatmap(data) or (None, None)
        main_corr, target_corr = corr_paths
        if main_corr:
            result_files['correlation_main'] = main_corr
        if target_corr:
            result_files['correlation_target'] = target_corr

        # 3. Distribution comparison
        if processed_data is not None:
            logger.info("Creating distribution comparison...")
            dist_path = self.create_distribution_comparison(data, processed_data)
            if dist_path:
                result_files['distribution'] = dist_path

    # 4. Feature importance
    if feature_importance:
        logger.info("Creating feature importance plot...")
        feat_path = self.create_feature_importance_plot(feature_importance)
        if feat_path:
            result_files['feature_importance'] = feat_path

    # 5. Time series decomposition
    if decomposition_result:
        logger.info("Creating time series decomposition...")
        decomp_path = self.create_time_series_decomposition_plot(decomposition_result)
        if decomp_path:
            result_files['decomposition'] = decomp_path

    # 6. Data quality report
    if validation_results:
        logger.info("Creating data quality report...")
        quality_path = self.create_data_quality_report(validation_results)
        if quality_path:
            result_files['quality_report'] = quality_path

    # Save information about all plots
    self.save_plots_info()

    logger.info("\n" + "="*80)
    logger.info("VISUALISATIONS SUCCESSFULLY CREATED")
    logger.info("="*80)
    for plot_name, plot_path in result_files.items():
        logger.info(f"✓ {plot_name}: {plot_path}")

    return result_files
def get_all_plots(self) -> Dict:
    """
    Get information about all created plots

    Returns:
    --------
    Dict : the manager's ``plot_files`` registry itself (not a copy), so
        changes made by the caller are visible to the manager
    """
    registry = self.plot_files
    return registry
def save_plots_info(self, filename: str = "plots_info.json") -> None:
    """
    Save plot information to JSON file

    The file is written into ``self.reports_dir`` and records the plot
    registry, the output directories, the generation time and the target
    column / results directory from the configuration. Any failure is
    logged and swallowed; nothing is raised to the caller.

    Parameters:
    -----------
    filename : str
        Name of the JSON file created inside the reports directory
    """
    # Categories whose "<name>_dir" attribute is reported, in output order
    directory_names = ('correlations', 'distributions', 'features',
                       'time_series', 'preprocessing', 'summary', 'reports')
    try:
        payload = {
            'total_plots': len(self.plot_files),
            'plots': self.plot_files,
        }
        payload['directories'] = {
            name: getattr(self, f"{name}_dir") for name in directory_names
        }
        payload['generation_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        payload['config'] = {
            'target_column': self.config.target_column,
            'results_dir': self.config.results_dir,
        }

        target = os.path.join(self.reports_dir, filename)
        with open(target, 'w', encoding='utf-8') as handle:
            # default=str: values json cannot encode are stored as strings
            json.dump(payload, handle, indent=4, ensure_ascii=False, default=str)

        logger.info(f"✓ Plot information saved: {target}")
    except Exception as err:
        logger.error(f"✗ Error saving plot information: {err}")
def move_existing_plots(self, source_dir: Optional[str] = None) -> Dict[str, str]:
    """
    Move existing plots from specified directory to structured folders

    Only regular files whose name ends in '.png' are moved. Known
    filenames go to their dedicated sub-folder of ``self.plots_dir``;
    every other PNG goes to the 'summary' folder. A file that cannot be
    moved is logged and skipped.

    Parameters:
    -----------
    source_dir : str, optional
        Directory with existing plots; defaults to ``self.plots_dir``

    Returns:
    --------
    Dict[str, str] : dictionary with information about moved files
        (filename -> new path); empty when the source directory is missing
    """
    # Local import: shutil is not among this module's file-level imports
    import shutil

    if source_dir is None:
        source_dir = self.plots_dir

    # isdir rather than exists: a plain file here would make listdir raise
    if not os.path.isdir(source_dir):
        logger.warning(f"Source directory doesn't exist: {source_dir}")
        return {}

    # Folder for any PNG that is not listed in the mapping below
    default_folder = 'summary'

    # File to folder mapping
    file_to_folder_map = {
        # Time series
        'data_split.png': 'time_series',
        'stationarity_raskhodvoda.png': 'time_series',
        'stationarity_analysis.png': 'time_series',
        'temporal_outliers.png': 'time_series',
        # Correlations
        'feature_selection_correlation.png': 'correlations',
        # Preprocessing
        'missing_values_analysis.png': 'preprocessing',
        'outlier_handling_results.png': 'preprocessing',
        'outliers_analysis.png': 'preprocessing',
        'scaling_results.png': 'preprocessing',
    }

    moved_files: Dict[str, str] = {}
    for filename in os.listdir(source_dir):
        source_path = os.path.join(source_dir, filename)
        # Skip non-PNG entries and directories that merely end in '.png'
        if not filename.endswith('.png') or not os.path.isfile(source_path):
            continue

        # Determine destination folder
        target_folder = file_to_folder_map.get(filename, default_folder)
        target_dir = os.path.join(self.plots_dir, target_folder)

        # Create destination folder if doesn't exist
        os.makedirs(target_dir, exist_ok=True)

        # Target path
        target_path = os.path.join(target_dir, filename)
        try:
            # shutil.move falls back to copy-and-delete when source and
            # destination are on different filesystems, where os.rename fails
            shutil.move(source_path, target_path)
        except Exception as e:
            logger.error(f"Error moving {filename}: {e}")
        else:
            moved_files[filename] = target_path
            logger.info(f"Moved: {filename} -> {target_folder}/")

    logger.info(f"Moved {len(moved_files)} files")
    return moved_files