Spaces:
Running
Running
| # ============================================ | |
| # CLASS 9: DATA SPLITTING | |
| # ============================================ | |
import logging
from datetime import datetime
from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from config.config import Config

# Module-level logger for this module.
# BUG FIX: the original `from venv import logger` accidentally imported the
# stdlib venv module's internal logger; use the standard per-module logger.
logger = logging.getLogger(__name__)
class DataSplitter:
    """Class for splitting data into train, validation and test sets"""

    def __init__(self, config: Config):
        """
        Initialise data splitter

        Parameters:
        -----------
        config : Config
            Experiment configuration
        """
        # Experiment-wide settings (sizes, split method, plotting flags, ...).
        self.config = config
        # No split has been performed yet: empty report, empty index map,
        # and no strategy label until one of the `_*_split` helpers runs.
        self.split_strategy = None
        self.split_info = {}
        self.split_indices = {}
| def split( | |
| self, | |
| data: pd.DataFrame, | |
| test_size: Optional[float] = None, | |
| validation_size: Optional[float] = None, | |
| method: str = None, | |
| random_state: int = 42, | |
| **kwargs | |
| ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: | |
| """ | |
| Split data into train, validation and test sets | |
| Parameters: | |
| ----------- | |
| data : pd.DataFrame | |
| Input data | |
| test_size : float, optional | |
| Test set size. If None, uses configuration value. | |
| validation_size : float, optional | |
| Validation set size. If None, uses configuration value. | |
| method : str, optional | |
| Splitting method: 'time', 'random', 'expanding_window', 'sliding_window' | |
| random_state : int | |
| Seed for reproducibility | |
| **kwargs : dict | |
| Additional parameters for method | |
| Returns: | |
| -------- | |
| Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame] | |
| Train, validation and test data | |
| """ | |
| logger.info("\n" + "="*80) | |
| logger.info("DATA SPLITTING") | |
| logger.info("="*80) | |
| test_size = test_size or self.config.test_size | |
| validation_size = validation_size or self.config.validation_size | |
| method = method or self.config.split_method | |
| n = len(data) | |
| logger.info(f"Total data: {n} records") | |
| logger.info(f"Splitting method: {method}") | |
| logger.info(f"Sizes: train={1-test_size-validation_size:.1%}, val={validation_size:.1%}, test={test_size:.1%}") | |
| if method == 'time': | |
| train_data, val_data, test_data = self._time_based_split( | |
| data, test_size, validation_size | |
| ) | |
| elif method == 'random': | |
| train_data, val_data, test_data = self._random_split( | |
| data, test_size, validation_size, random_state | |
| ) | |
| elif method == 'expanding_window': | |
| train_data, val_data, test_data = self._expanding_window_split( | |
| data, test_size, validation_size, **kwargs | |
| ) | |
| elif method == 'sliding_window': | |
| train_data, val_data, test_data = self._sliding_window_split( | |
| data, **kwargs | |
| ) | |
| else: | |
| logger.warning(f"Method {method} not supported, using time-based split") | |
| train_data, val_data, test_data = self._time_based_split( | |
| data, test_size, validation_size | |
| ) | |
| # Save splitting information | |
| self._save_split_info(data, train_data, val_data, test_data, method) | |
| # Output information | |
| self._log_split_summary(train_data, val_data, test_data) | |
| # Visualisation of split | |
| if self.config.save_plots: | |
| self._plot_data_split(data, train_data, val_data, test_data) | |
| return train_data, val_data, test_data | |
| def _time_based_split( | |
| self, | |
| data: pd.DataFrame, | |
| test_size: float, | |
| validation_size: float | |
| ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: | |
| """Time-based splitting preserving temporal order""" | |
| n = len(data) | |
| # Calculate set sizes | |
| test_size_int = int(n * test_size) | |
| val_size_int = int(n * validation_size) | |
| train_size_int = n - test_size_int - val_size_int | |
| # Split data | |
| train_data = data.iloc[:train_size_int].copy() | |
| val_data = data.iloc[train_size_int:train_size_int + val_size_int].copy() | |
| test_data = data.iloc[train_size_int + val_size_int:].copy() | |
| self.split_strategy = 'time_based' | |
| return train_data, val_data, test_data | |
| def _random_split( | |
| self, | |
| data: pd.DataFrame, | |
| test_size: float, | |
| validation_size: float, | |
| random_state: int | |
| ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: | |
| """Random data splitting""" | |
| from sklearn.model_selection import train_test_split | |
| # First split into train+val and test | |
| train_val_data, test_data = train_test_split( | |
| data, | |
| test_size=test_size, | |
| random_state=random_state, | |
| shuffle=True | |
| ) | |
| # Then split train+val into train and val | |
| val_relative_size = validation_size / (1 - test_size) | |
| train_data, val_data = train_test_split( | |
| train_val_data, | |
| test_size=val_relative_size, | |
| random_state=random_state, | |
| shuffle=True | |
| ) | |
| self.split_strategy = 'random' | |
| return train_data, val_data, test_data | |
| def _expanding_window_split( | |
| self, | |
| data: pd.DataFrame, | |
| test_size: float, | |
| validation_size: float, | |
| **kwargs | |
| ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: | |
| """Expanding window split""" | |
| n = len(data) | |
| # Minimum initial window size | |
| initial_window = kwargs.get('initial_window', max(100, int(n * 0.1))) | |
| # Final set sizes | |
| test_size_int = int(n * test_size) | |
| val_size_int = int(n * validation_size) | |
| # Determine boundaries | |
| test_start = n - test_size_int | |
| val_start = test_start - val_size_int | |
| # For expanding window, use all data up to val_start for training | |
| train_data = data.iloc[:val_start].copy() | |
| val_data = data.iloc[val_start:test_start].copy() | |
| test_data = data.iloc[test_start:].copy() | |
| self.split_strategy = 'expanding_window' | |
| return train_data, val_data, test_data | |
| def _sliding_window_split( | |
| self, | |
| data: pd.DataFrame, | |
| **kwargs | |
| ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: | |
| """Sliding window split (for multiple train-val-test pairs)""" | |
| window_size = kwargs.get('window_size', len(data) // 3) | |
| step = kwargs.get('step', window_size // 2) | |
| # For simplicity return single split | |
| # In real scenarios can return list of splits | |
| n = len(data) | |
| train_end = n - window_size | |
| val_end = train_end + window_size // 3 | |
| test_end = n | |
| train_data = data.iloc[:train_end].copy() | |
| val_data = data.iloc[train_end:val_end].copy() | |
| test_data = data.iloc[val_end:].copy() | |
| self.split_strategy = 'sliding_window' | |
| return train_data, val_data, test_data | |
| def _save_split_info( | |
| self, | |
| full_data: pd.DataFrame, | |
| train_data: pd.DataFrame, | |
| val_data: pd.DataFrame, | |
| test_data: pd.DataFrame, | |
| method: str | |
| ) -> None: | |
| """Save splitting information""" | |
| n = len(full_data) | |
| self.split_info = { | |
| 'method': method, | |
| 'strategy': self.split_strategy, | |
| 'train_size': len(train_data), | |
| 'val_size': len(val_data), | |
| 'test_size': len(test_data), | |
| 'train_percent': len(train_data) / n * 100, | |
| 'val_percent': len(val_data) / n * 100, | |
| 'test_percent': len(test_data) / n * 100, | |
| 'total_samples': n, | |
| 'features_count': len(full_data.columns), | |
| 'split_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S') | |
| } | |
| # Add temporal period information if available | |
| if isinstance(full_data.index, pd.DatetimeIndex): | |
| self.split_info.update({ | |
| 'train_period': { | |
| 'start': train_data.index.min().strftime('%Y-%m-%d'), | |
| 'end': train_data.index.max().strftime('%Y-%m-%d') | |
| }, | |
| 'val_period': { | |
| 'start': val_data.index.min().strftime('%Y-%m-%d'), | |
| 'end': val_data.index.max().strftime('%Y-%m-%d') | |
| }, | |
| 'test_period': { | |
| 'start': test_data.index.min().strftime('%Y-%m-%d'), | |
| 'end': test_data.index.max().strftime('%Y-%m-%d') | |
| } | |
| }) | |
| # Save split indices | |
| self.split_indices = { | |
| 'train': train_data.index.tolist(), | |
| 'val': val_data.index.tolist(), | |
| 'test': test_data.index.tolist() | |
| } | |
| def _log_split_summary( | |
| self, | |
| train_data: pd.DataFrame, | |
| val_data: pd.DataFrame, | |
| test_data: pd.DataFrame | |
| ) -> None: | |
| """Log splitting summary""" | |
| logger.info("✓ Data split completed:") | |
| logger.info(f" Train: {len(train_data)} records ({self.split_info['train_percent']:.1f}%)") | |
| logger.info(f" Validation: {len(val_data)} records ({self.split_info['val_percent']:.1f}%)") | |
| logger.info(f" Test: {len(test_data)} records ({self.split_info['test_percent']:.1f}%)") | |
| if 'train_period' in self.split_info: | |
| logger.info(f"\nPeriods:") | |
| logger.info(f" Train: {self.split_info['train_period']['start']} - {self.split_info['train_period']['end']}") | |
| logger.info(f" Validation: {self.split_info['val_period']['start']} - {self.split_info['val_period']['end']}") | |
| logger.info(f" Test: {self.split_info['test_period']['start']} - {self.split_info['test_period']['end']}") | |
| # Target variable statistics | |
| target = self.config.target_column | |
| if target in train_data.columns: | |
| logger.info(f"\nTarget variable '{target}' statistics:") | |
| logger.info(f" Train: mean={train_data[target].mean():.2f}, std={train_data[target].std():.2f}") | |
| logger.info(f" Validation: mean={val_data[target].mean():.2f}, std={val_data[target].std():.2f}") | |
| logger.info(f" Test: mean={test_data[target].mean():.2f}, std={test_data[target].std():.2f}") | |
| def _plot_data_split( | |
| self, | |
| full_data: pd.DataFrame, | |
| train_data: pd.DataFrame, | |
| val_data: pd.DataFrame, | |
| test_data: pd.DataFrame | |
| ) -> None: | |
| """Visualise data splitting""" | |
| fig, axes = plt.subplots(2, 2, figsize=(14, 10)) | |
| target = self.config.target_column | |
| # 1. Time series with set highlighting | |
| if target in full_data.columns and isinstance(full_data.index, pd.DatetimeIndex): | |
| axes[0, 0].plot(train_data.index, train_data[target], | |
| label='Train', colour='blue', alpha=0.7, linewidth=1) | |
| axes[0, 0].plot(val_data.index, val_data[target], | |
| label='Validation', colour='orange', alpha=0.7, linewidth=1) | |
| axes[0, 0].plot(test_data.index, test_data[target], | |
| label='Test', colour='red', alpha=0.7, linewidth=1) | |
| axes[0, 0].set_title(f'Data Split: {target}') | |
| axes[0, 0].set_xlabel('Date') | |
| axes[0, 0].set_ylabel(target) | |
| axes[0, 0].legend() | |
| axes[0, 0].grid(True, alpha=0.3) | |
| # 2. Yearly distribution | |
| if isinstance(full_data.index, pd.DatetimeIndex): | |
| full_data['year'] = full_data.index.year | |
| train_data['year'] = train_data.index.year | |
| val_data['year'] = val_data.index.year | |
| test_data['year'] = test_data.index.year | |
| years = sorted(full_data['year'].unique()) | |
| train_counts = [len(train_data[train_data['year'] == year]) for year in years] | |
| val_counts = [len(val_data[val_data['year'] == year]) for year in years] | |
| test_counts = [len(test_data[test_data['year'] == year]) for year in years] | |
| x = np.arange(len(years)) | |
| width = 0.25 | |
| axes[0, 1].bar(x - width, train_counts, width, label='Train', colour='blue', alpha=0.7) | |
| axes[0, 1].bar(x, val_counts, width, label='Validation', colour='orange', alpha=0.7) | |
| axes[0, 1].bar(x + width, test_counts, width, label='Test', colour='red', alpha=0.7) | |
| axes[0, 1].set_title('Yearly Data Distribution') | |
| axes[0, 1].set_xlabel('Year') | |
| axes[0, 1].set_ylabel('Number of Records') | |
| axes[0, 1].set_xticks(x) | |
| axes[0, 1].set_xticklabels(years, rotation=45) | |
| axes[0, 1].legend() | |
| axes[0, 1].grid(True, alpha=0.3) | |
| # Remove added columns | |
| for df in [full_data, train_data, val_data, test_data]: | |
| if 'year' in df.columns: | |
| df.drop('year', axis=1, inplace=True) | |
| # 3. Target variable distribution | |
| if target in full_data.columns: | |
| axes[1, 0].hist(train_data[target].dropna(), bins=30, alpha=0.5, label='Train', density=True) | |
| axes[1, 0].hist(val_data[target].dropna(), bins=30, alpha=0.5, label='Validation', density=True) | |
| axes[1, 0].hist(test_data[target].dropna(), bins=30, alpha=0.5, label='Test', density=True) | |
| axes[1, 0].set_title(f'{target} Distribution Across Sets') | |
| axes[1, 0].set_xlabel(target) | |
| axes[1, 0].set_ylabel('Density') | |
| axes[1, 0].legend() | |
| axes[1, 0].grid(True, alpha=0.3) | |
| # 4. Set statistics | |
| if target in full_data.columns: | |
| stats_data = [] | |
| for name, df in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]: | |
| if target in df.columns: | |
| stats_data.append({ | |
| 'Dataset': name, | |
| 'Mean': df[target].mean(), | |
| 'Std': df[target].std(), | |
| 'Min': df[target].min(), | |
| 'Max': df[target].max() | |
| }) | |
| if stats_data: | |
| stats_df = pd.DataFrame(stats_data) | |
| stats_table = axes[1, 1].table( | |
| cellText=stats_df.round(2).values, | |
| colLabels=stats_df.columns, | |
| cellLoc='center', | |
| loc='center' | |
| ) | |
| stats_table.auto_set_font_size(False) | |
| stats_table.set_fontsize(9) | |
| stats_table.scale(1, 1.5) | |
| axes[1, 1].axis('off') | |
| axes[1, 1].set_title('Set Statistics') | |
| plt.suptitle(f'Data Splitting: {self.split_info["method"]} method', fontsize=14) | |
| plt.tight_layout() | |
| plt.savefig( | |
| f'{self.config.results_dir}/plots/data_split.png', | |
| dpi=300, | |
| bbox_inches='tight' | |
| ) | |
| plt.show() | |
| def get_report(self) -> Dict: | |
| """Get data splitting report""" | |
| return self.split_info |