Spaces:
Sleeping
Sleeping
| """Split manager for separating DataFrames by train_split column.""" | |
| from typing import ClassVar | |
| import pandas as pd | |
| from app.core.exceptions import DatasetError | |
class SplitManager:
    """Separates DataFrames by train_split column into train/validation/test."""

    VALID_SPLITS: ClassVar[set[str]] = {"train", "validation", "test"}

    def get_splits(
        self, df: pd.DataFrame, split_col: str = "train_split"
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Return (train_df, val_df, test_df) as new DataFrame copies.

        Rows are assigned purely by the value in ``split_col``; there is no
        random sampling or shuffling, and each output keeps the original
        row order and index labels.

        Args:
            df: Input table containing a split-assignment column.
            split_col: Name of the column whose value for every row must be
                one of ``VALID_SPLITS``.

        Returns:
            A ``(train_df, val_df, test_df)`` tuple of independent copies.

        Raises:
            DatasetError: If ``split_col`` is missing, contains null values,
                contains empty strings, contains values outside
                ``VALID_SPLITS``, or the split row counts do not add up to
                ``len(df)``.
        """
        if split_col not in df.columns:
            raise DatasetError(
                f"Column '{split_col}' not found in DataFrame. "
                f"Available columns: {list(df.columns)}. "
                f"This table may not be a training table."
            )

        split_values = df[split_col]

        # Missing assignments are checked before the generic invalid-value
        # check: "" is not a valid split, so checking it afterwards would make
        # the more specific empty-string message unreachable.
        null_count = split_values.isna().sum()
        if null_count > 0:
            raise DatasetError(
                f"Column '{split_col}' contains {null_count} null value(s). "
                f"All rows must have a valid split assignment."
            )

        empty_count = (split_values == "").sum()
        if empty_count > 0:
            raise DatasetError(
                f"Column '{split_col}' contains {empty_count} empty string value(s). "
                f"All rows must have a valid split assignment from {sorted(self.VALID_SPLITS)}."
            )

        invalid_values = set(split_values.unique()) - self.VALID_SPLITS
        if invalid_values:
            # key=str keeps the error path from raising TypeError when the
            # column mixes types (e.g. ints and strings); for all-string
            # values the ordering matches plain sorted().
            raise DatasetError(
                f"Column '{split_col}' contains invalid values: "
                f"{sorted(invalid_values, key=str)}. "
                f"Only {sorted(self.VALID_SPLITS)} are allowed."
            )

        # Filter by column value only — no random sampling, shuffling, or sklearn splitters
        train_df = df[split_values == "train"].copy()
        val_df = df[split_values == "validation"].copy()
        test_df = df[split_values == "test"].copy()

        # Defensive row-conservation check; the validation above should make
        # this impossible, so a failure here signals something unexpected.
        total_split_rows = len(train_df) + len(val_df) + len(test_df)
        if total_split_rows != len(df):
            raise DatasetError(
                f"Row conservation violated: split rows ({total_split_rows}) "
                f"!= original rows ({len(df)}). "
                f"This indicates data corruption or unexpected split values."
            )

        return train_df, val_df, test_df