"""Split manager for separating DataFrames by train_split column.""" from typing import ClassVar import pandas as pd from app.core.exceptions import DatasetError class SplitManager: """Separates DataFrames by train_split column into train/validation/test.""" VALID_SPLITS: ClassVar[set[str]] = {"train", "validation", "test"} def get_splits( self, df: pd.DataFrame, split_col: str = "train_split" ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """Return (train_df, val_df, test_df) as new DataFrame copies. Raises DatasetError if split_col missing or contains invalid values. """ if split_col not in df.columns: raise DatasetError( f"Column '{split_col}' not found in DataFrame. " f"Available columns: {list(df.columns)}. " f"This table may not be a training table." ) unique_values = set(df[split_col].dropna().unique()) invalid_values = unique_values - self.VALID_SPLITS if invalid_values: raise DatasetError( f"Column '{split_col}' contains invalid values: {sorted(invalid_values)}. " f"Only {sorted(self.VALID_SPLITS)} are allowed." ) # Check for null/empty values in the split column null_count = df[split_col].isna().sum() if null_count > 0: raise DatasetError( f"Column '{split_col}' contains {null_count} null value(s). " f"All rows must have a valid split assignment." ) empty_count = (df[split_col] == "").sum() if empty_count > 0: raise DatasetError( f"Column '{split_col}' contains {empty_count} empty string value(s). " f"All rows must have a valid split assignment from {sorted(self.VALID_SPLITS)}." ) # Filter by column value only — no random sampling, shuffling, or sklearn splitters train_df = df[df[split_col] == "train"].copy() val_df = df[df[split_col] == "validation"].copy() test_df = df[df[split_col] == "test"].copy() # Verify row conservation total_split_rows = len(train_df) + len(val_df) + len(test_df) if total_split_rows != len(df): raise DatasetError( f"Row conservation violated: split rows ({total_split_rows}) " f"!= original rows ({len(df)}). " f"This indicates data corruption or unexpected split values." ) return train_df, val_df, test_df