Spaces:
Sleeping
Sleeping
File size: 2,596 Bytes
70ea7be | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | """Split manager for separating DataFrames by train_split column."""
from typing import ClassVar
import pandas as pd
from app.core.exceptions import DatasetError
class SplitManager:
    """Separates DataFrames by train_split column into train/validation/test.

    Rows are assigned purely by the label stored in the split column; no
    sampling or shuffling is performed.
    """

    # The only labels a split column may contain.
    VALID_SPLITS: ClassVar[set[str]] = {"train", "validation", "test"}

    @staticmethod
    def _sorted_for_message(values):
        """Return *values* as a list in a deterministic order for error text.

        Natural ordering is used when the values are mutually comparable;
        otherwise (e.g. a mix of ``int`` and ``str``) they are ordered by
        ``repr`` so that building the message never raises ``TypeError``.
        """
        try:
            return sorted(values)
        except TypeError:
            return sorted(values, key=repr)

    def get_splits(
        self, df: pd.DataFrame, split_col: str = "train_split"
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Return (train_df, val_df, test_df) as new DataFrame copies.

        Args:
            df: DataFrame holding the rows to separate.
            split_col: Name of the column carrying the split label.

        Returns:
            Three DataFrame copies, in the order train, validation, test.
            A split with no matching rows is an empty DataFrame.

        Raises:
            DatasetError: If ``split_col`` is missing, contains a value
                outside ``VALID_SPLITS``, contains nulls or empty strings,
                or if the three splits do not add up to ``len(df)``.
        """
        if split_col not in df.columns:
            raise DatasetError(
                f"Column '{split_col}' not found in DataFrame. "
                f"Available columns: {list(df.columns)}. "
                f"This table may not be a training table."
            )

        split_values = df[split_col]

        # Nulls and empty strings have dedicated checks (and messages) below,
        # so they are excluded here. Without excluding "", the empty-string
        # check could never be reached because "" is not a valid split.
        unique_values = set(split_values.dropna().unique())
        invalid_values = unique_values - self.VALID_SPLITS - {""}
        if invalid_values:
            raise DatasetError(
                f"Column '{split_col}' contains invalid values: "
                f"{self._sorted_for_message(invalid_values)}. "
                f"Only {sorted(self.VALID_SPLITS)} are allowed."
            )

        # Check for null values in the split column
        null_count = int(split_values.isna().sum())
        if null_count > 0:
            raise DatasetError(
                f"Column '{split_col}' contains {null_count} null value(s). "
                f"All rows must have a valid split assignment."
            )

        # Check for empty-string values in the split column
        empty_count = int((split_values == "").sum())
        if empty_count > 0:
            raise DatasetError(
                f"Column '{split_col}' contains {empty_count} empty string value(s). "
                f"All rows must have a valid split assignment from {sorted(self.VALID_SPLITS)}."
            )

        # Filter by column value only — no random sampling, shuffling, or sklearn splitters
        train_df = df[split_values == "train"].copy()
        val_df = df[split_values == "validation"].copy()
        test_df = df[split_values == "test"].copy()

        # Verify row conservation: a defensive guard in case a row slipped
        # past the validation above without landing in any split.
        total_split_rows = len(train_df) + len(val_df) + len(test_df)
        if total_split_rows != len(df):
            raise DatasetError(
                f"Row conservation violated: split rows ({total_split_rows}) "
                f"!= original rows ({len(df)}). "
                f"This indicates data corruption or unexpected split values."
            )

        return train_df, val_df, test_df
|