# ============================================
# CLASS 9: DATA SPLITTING
# ============================================
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from config.config import Config

# Use this module's own logger so records carry this module's name and follow
# the application's logging configuration (instead of borrowing the logger of
# an unrelated stdlib package).
logger = logging.getLogger(__name__)


class DataSplitter:
    """Class for splitting data into train, validation and test sets"""

    def __init__(self, config: Config):
        """
        Initialise data splitter

        Parameters:
        -----------
        config : Config
            Experiment configuration. This class reads ``test_size``,
            ``validation_size``, ``split_method``, ``save_plots``,
            ``target_column`` and ``results_dir`` from it.
        """
        self.config = config
        # Summary of the most recent split; filled in by ``_save_split_info``.
        self.split_info = {}
        # Index labels assigned to each subset by the most recent split.
        self.split_indices = {}
        # Strategy name recorded by whichever split helper ran last.
        self.split_strategy = None

    def split(
        self,
        data: pd.DataFrame,
        test_size: Optional[float] = None,
        validation_size: Optional[float] = None,
        method: Optional[str] = None,
        random_state: int = 42,
        **kwargs
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Split data into train, validation and test sets

        Parameters:
        -----------
        data : pd.DataFrame
            Input data. It is not modified.
        test_size : float, optional
            Fraction of rows for the test set. If None, uses configuration value.
        validation_size : float, optional
            Fraction of rows for the validation set. If None, uses
            configuration value.
        method : str, optional
            Splitting method: 'time', 'random', 'expanding_window',
            'sliding_window'. If None, uses configuration value. Unknown
            methods fall back to the time-based split with a warning.
        random_state : int
            Seed for reproducibility (used by the 'random' method)
        **kwargs : dict
            Additional parameters for the 'expanding_window' and
            'sliding_window' methods

        Returns:
        --------
        Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
            Train, validation and test data

        Raises:
        -------
        ValueError
            If ``data`` is empty, a fraction is negative, or
            ``test_size + validation_size >= 1``.
        """
        logger.info("\n" + "="*80)
        logger.info("DATA SPLITTING")
        logger.info("="*80)

        # Compare against None explicitly: an explicit 0.0 is a legitimate
        # request and must not be silently replaced by the configured value.
        if test_size is None:
            test_size = self.config.test_size
        if validation_size is None:
            validation_size = self.config.validation_size
        method = method or self.config.split_method
        self._validate_split_args(data, test_size, validation_size)

        n = len(data)
        logger.info(f"Total data: {n} records")
        logger.info(f"Splitting method: {method}")
        logger.info(f"Sizes: train={1-test_size-validation_size:.1%}, val={validation_size:.1%}, test={test_size:.1%}")

        if method == 'time':
            train_data, val_data, test_data = self._time_based_split(
                data, test_size, validation_size
            )
        elif method == 'random':
            train_data, val_data, test_data = self._random_split(
                data, test_size, validation_size, random_state
            )
        elif method == 'expanding_window':
            train_data, val_data, test_data = self._expanding_window_split(
                data, test_size, validation_size, **kwargs
            )
        elif method == 'sliding_window':
            train_data, val_data, test_data = self._sliding_window_split(
                data, **kwargs
            )
        else:
            logger.warning(f"Method {method} not supported, using time-based split")
            train_data, val_data, test_data = self._time_based_split(
                data, test_size, validation_size
            )

        # Save splitting information
        self._save_split_info(data, train_data, val_data, test_data, method)

        # Output information
        self._log_split_summary(train_data, val_data, test_data)

        # Visualisation of split
        if self.config.save_plots:
            self._plot_data_split(data, train_data, val_data, test_data)

        return train_data, val_data, test_data

    @staticmethod
    def _validate_split_args(
        data: pd.DataFrame,
        test_size: float,
        validation_size: float
    ) -> None:
        """Reject inputs that would otherwise fail deep inside a split helper."""
        if len(data) == 0:
            raise ValueError("Cannot split an empty DataFrame")
        if test_size < 0 or validation_size < 0:
            raise ValueError(
                f"Split fractions must be non-negative, got "
                f"test_size={test_size}, validation_size={validation_size}"
            )
        if test_size + validation_size >= 1:
            raise ValueError(
                f"test_size + validation_size must be < 1, got "
                f"{test_size} + {validation_size}"
            )

    def _time_based_split(
        self,
        data: pd.DataFrame,
        test_size: float,
        validation_size: float
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Time-based splitting preserving temporal order.

        Rows are taken in their current order: the first part is train, then
        validation, then test. Set sizes are truncated with ``int``, so any
        rounding remainder goes to the training set.
        """
        n = len(data)

        # Calculate set sizes
        test_size_int = int(n * test_size)
        val_size_int = int(n * validation_size)
        train_size_int = n - test_size_int - val_size_int

        # Split data
        train_data = data.iloc[:train_size_int].copy()
        val_data = data.iloc[train_size_int:train_size_int + val_size_int].copy()
        test_data = data.iloc[train_size_int + val_size_int:].copy()

        self.split_strategy = 'time_based'

        return train_data, val_data, test_data

    def _random_split(
        self,
        data: pd.DataFrame,
        test_size: float,
        validation_size: float,
        random_state: int
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Random (shuffled) data splitting via scikit-learn.

        A fraction of 0 produces an empty subset directly, because
        ``train_test_split`` rejects ``test_size=0.0``.
        """
        from sklearn.model_selection import train_test_split

        def hold_out(frame: pd.DataFrame, fraction: float):
            """Return (remainder, held-out part) for ``fraction`` of ``frame``."""
            if fraction <= 0:
                return frame.copy(), frame.iloc[0:0].copy()
            return train_test_split(
                frame,
                test_size=fraction,
                random_state=random_state,
                shuffle=True
            )

        # First split into train+val and test
        train_val_data, test_data = hold_out(data, test_size)

        # Then split train+val into train and val. validation_size is a
        # fraction of the full data, so rescale it to the remaining rows.
        val_relative_size = validation_size / (1 - test_size)
        train_data, val_data = hold_out(train_val_data, val_relative_size)

        self.split_strategy = 'random'

        return train_data, val_data, test_data

    def _expanding_window_split(
        self,
        data: pd.DataFrame,
        test_size: float,
        validation_size: float,
        **kwargs
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Expanding window split.

        Only a single split is produced: all rows before the validation block
        form the training set. The boundaries therefore coincide with the
        time-based split. Extra keyword arguments (e.g. ``initial_window``) are
        accepted for interface compatibility but currently unused.
        """
        n = len(data)

        # Final set sizes
        test_size_int = int(n * test_size)
        val_size_int = int(n * validation_size)

        # Determine boundaries
        test_start = n - test_size_int
        val_start = test_start - val_size_int

        # For expanding window, use all data up to val_start for training
        train_data = data.iloc[:val_start].copy()
        val_data = data.iloc[val_start:test_start].copy()
        test_data = data.iloc[test_start:].copy()

        self.split_strategy = 'expanding_window'

        return train_data, val_data, test_data

    def _sliding_window_split(
        self,
        data: pd.DataFrame,
        **kwargs
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Sliding window split (single train-val-test pair for now).

        The last ``window_size`` rows (default ``len(data) // 3``) are held
        out: their first third becomes validation and the rest becomes test.
        All earlier rows form the training set. A ``step`` keyword is accepted
        but currently ignored, since only one split is returned.

        Raises:
        -------
        ValueError
            If ``window_size`` is negative or larger than ``len(data)``.
        """
        n = len(data)
        window_size = kwargs.get('window_size', n // 3)
        if not 0 <= window_size <= n:
            raise ValueError(
                f"window_size must be between 0 and {n}, got {window_size}"
            )

        train_end = n - window_size
        val_end = train_end + window_size // 3

        train_data = data.iloc[:train_end].copy()
        val_data = data.iloc[train_end:val_end].copy()
        test_data = data.iloc[val_end:].copy()

        self.split_strategy = 'sliding_window'

        return train_data, val_data, test_data

    @staticmethod
    def _period(frame: pd.DataFrame) -> Dict[str, Optional[str]]:
        """Return first/last index dates as 'YYYY-MM-DD', or None bounds when empty."""
        start, end = frame.index.min(), frame.index.max()
        # An empty (or all-NaT) DatetimeIndex yields NaT, which cannot be
        # formatted with strftime.
        if pd.isna(start):
            return {'start': None, 'end': None}
        return {
            'start': start.strftime('%Y-%m-%d'),
            'end': end.strftime('%Y-%m-%d')
        }

    def _save_split_info(
        self,
        full_data: pd.DataFrame,
        train_data: pd.DataFrame,
        val_data: pd.DataFrame,
        test_data: pd.DataFrame,
        method: str
    ) -> None:
        """Store split sizes, periods and indices on ``split_info`` / ``split_indices``."""
        n = len(full_data)

        self.split_info = {
            'method': method,
            'strategy': self.split_strategy,
            'train_size': len(train_data),
            'val_size': len(val_data),
            'test_size': len(test_data),
            'train_percent': len(train_data) / n * 100,
            'val_percent': len(val_data) / n * 100,
            'test_percent': len(test_data) / n * 100,
            'total_samples': n,
            'features_count': len(full_data.columns),
            'split_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        # Add temporal period information if available
        if isinstance(full_data.index, pd.DatetimeIndex):
            self.split_info.update({
                'train_period': self._period(train_data),
                'val_period': self._period(val_data),
                'test_period': self._period(test_data)
            })

        # Save split indices
        self.split_indices = {
            'train': train_data.index.tolist(),
            'val': val_data.index.tolist(),
            'test': test_data.index.tolist()
        }

    def _log_split_summary(
        self,
        train_data: pd.DataFrame,
        val_data: pd.DataFrame,
        test_data: pd.DataFrame
    ) -> None:
        """Log set sizes, periods (if recorded) and target statistics."""
        logger.info("✓ Data split completed:")
        logger.info(f"  Train: {len(train_data)} records ({self.split_info['train_percent']:.1f}%)")
        logger.info(f"  Validation: {len(val_data)} records ({self.split_info['val_percent']:.1f}%)")
        logger.info(f"  Test: {len(test_data)} records ({self.split_info['test_percent']:.1f}%)")

        if 'train_period' in self.split_info:
            logger.info("\nPeriods:")
            logger.info(f"  Train: {self.split_info['train_period']['start']} - {self.split_info['train_period']['end']}")
            logger.info(f"  Validation: {self.split_info['val_period']['start']} - {self.split_info['val_period']['end']}")
            logger.info(f"  Test: {self.split_info['test_period']['start']} - {self.split_info['test_period']['end']}")

        # Target variable statistics
        target = self.config.target_column
        if target in train_data.columns:
            logger.info(f"\nTarget variable '{target}' statistics:")
            logger.info(f"  Train: mean={train_data[target].mean():.2f}, std={train_data[target].std():.2f}")
            logger.info(f"  Validation: mean={val_data[target].mean():.2f}, std={val_data[target].std():.2f}")
            logger.info(f"  Test: mean={test_data[target].mean():.2f}, std={test_data[target].std():.2f}")

    def _plot_data_split(
        self,
        full_data: pd.DataFrame,
        train_data: pd.DataFrame,
        val_data: pd.DataFrame,
        test_data: pd.DataFrame
    ) -> None:
        """
        Visualise data splitting and save the figure to
        ``<results_dir>/plots/data_split.png``.

        None of the input frames are modified.
        """
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        target = self.config.target_column
        is_dated = isinstance(full_data.index, pd.DatetimeIndex)
        subsets = (
            ('Train', train_data, 'blue'),
            ('Validation', val_data, 'orange'),
            ('Test', test_data, 'red')
        )

        # 1. Time series with set highlighting
        if target in full_data.columns and is_dated:
            ax = axes[0, 0]
            for label, frame, color in subsets:
                # Sort by date so shuffled (random-split) subsets are drawn as
                # a line instead of a zig-zag between out-of-order points.
                ordered = frame.sort_index()
                ax.plot(ordered.index, ordered[target], label=label,
                        color=color, alpha=0.7, linewidth=1)
            ax.set_title(f'Data Split: {target}')
            ax.set_xlabel('Date')
            ax.set_ylabel(target)
            ax.legend()
            ax.grid(True, alpha=0.3)

        # 2. Yearly distribution. Counts come straight from the index, so no
        # temporary column is written into the caller's frames (that would
        # clobber, then drop, any existing 'year' feature).
        if is_dated:
            ax = axes[0, 1]
            years = sorted(full_data.index.year.dropna().unique())
            x = np.arange(len(years))
            width = 0.25
            for offset, (label, frame, color) in zip((-width, 0, width), subsets):
                counts = frame.index.year.value_counts().reindex(years, fill_value=0)
                ax.bar(x + offset, counts.tolist(), width, label=label,
                       color=color, alpha=0.7)
            ax.set_title('Yearly Data Distribution')
            ax.set_xlabel('Year')
            ax.set_ylabel('Number of Records')
            ax.set_xticks(x)
            ax.set_xticklabels(years, rotation=45)
            ax.legend()
            ax.grid(True, alpha=0.3)

        # 3. Target variable distribution
        if target in full_data.columns:
            ax = axes[1, 0]
            for label, frame, _ in subsets:
                ax.hist(frame[target].dropna(), bins=30, alpha=0.5,
                        label=label, density=True)
            ax.set_title(f'{target} Distribution Across Sets')
            ax.set_xlabel(target)
            ax.set_ylabel('Density')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # 4. Set statistics
        if target in full_data.columns:
            stats_data = [
                {
                    'Dataset': label,
                    'Mean': frame[target].mean(),
                    'Std': frame[target].std(),
                    'Min': frame[target].min(),
                    'Max': frame[target].max()
                }
                for label, frame, _ in subsets
                if target in frame.columns
            ]

            if stats_data:
                stats_df = pd.DataFrame(stats_data)
                stats_table = axes[1, 1].table(
                    cellText=stats_df.round(2).values,
                    colLabels=stats_df.columns,
                    cellLoc='center',
                    loc='center'
                )
                stats_table.auto_set_font_size(False)
                stats_table.set_fontsize(9)
                stats_table.scale(1, 1.5)
                axes[1, 1].axis('off')
                axes[1, 1].set_title('Set Statistics')

        fig.suptitle(f'Data Splitting: {self.split_info["method"]} method', fontsize=14)
        fig.tight_layout()

        # savefig does not create missing directories.
        plots_dir = Path(self.config.results_dir) / 'plots'
        plots_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(plots_dir / 'data_split.png', dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure; pyplot otherwise keeps it alive for the whole
        # process, leaking memory when splits are run repeatedly.
        plt.close(fig)

    def get_report(self) -> Dict:
        """Return the summary dictionary produced by the most recent ``split`` call."""
        return self.split_info