# Source: revision 70ea7be (file size: 2,596 bytes).
"""Split manager for separating DataFrames by train_split column."""

from typing import ClassVar

import pandas as pd

from app.core.exceptions import DatasetError


class SplitManager:
    """Separates DataFrames by train_split column into train/validation/test.

    The manager holds no per-instance state: it only reads the split labels
    already stored in the DataFrame and filters rows by them.
    """

    # The only labels a split column is allowed to contain.
    VALID_SPLITS: ClassVar[set[str]] = {"train", "validation", "test"}

    def get_splits(
        self, df: pd.DataFrame, split_col: str = "train_split"
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Return (train_df, val_df, test_df) as new DataFrame copies.

        Args:
            df: Table whose rows are labelled in ``split_col``.
            split_col: Name of the column holding the split label.

        Returns:
            A ``(train_df, val_df, test_df)`` tuple. Each element is a copy of
            the rows whose label equals "train", "validation" and "test"
            respectively; a split without rows is an empty DataFrame.

        Raises:
            DatasetError: If ``split_col`` is missing, or if it contains
                null values, empty strings, or any label outside
                ``VALID_SPLITS``.
        """
        if split_col not in df.columns:
            raise DatasetError(
                f"Column '{split_col}' not found in DataFrame. "
                f"Available columns: {list(df.columns)}. "
                f"This table may not be a training table."
            )

        # Look the column up once; every check and filter below reuses it.
        labels = df[split_col]
        self._validate_labels(labels, split_col)

        # Filter by column value only — no random sampling, shuffling, or sklearn splitters
        train_df = df[labels == "train"].copy()
        val_df = df[labels == "validation"].copy()
        test_df = df[labels == "test"].copy()

        # Defensive check: after validation every row must land in exactly one split.
        total_split_rows = len(train_df) + len(val_df) + len(test_df)
        if total_split_rows != len(df):
            raise DatasetError(
                f"Row conservation violated: split rows ({total_split_rows}) "
                f"!= original rows ({len(df)}). "
                f"This indicates data corruption or unexpected split values."
            )

        return train_df, val_df, test_df

    def _validate_labels(self, labels: pd.Series, split_col: str) -> None:
        """Raise DatasetError unless every label in ``labels`` is in VALID_SPLITS."""
        allowed = sorted(self.VALID_SPLITS)
        unexpected = set(labels.dropna().unique()) - self.VALID_SPLITS

        # "" is not a valid split, so it always ends up in `unexpected`. When it
        # is the only offender, let it fall through to the dedicated
        # empty-string check below; otherwise that check could never fire.
        if unexpected - {""}:
            raise DatasetError(
                f"Column '{split_col}' contains invalid values: "
                f"{self._sorted_for_message(unexpected)}. "
                f"Only {allowed} are allowed."
            )

        # Check for null/empty values in the split column
        null_count = labels.isna().sum()
        if null_count > 0:
            raise DatasetError(
                f"Column '{split_col}' contains {null_count} null value(s). "
                f"All rows must have a valid split assignment."
            )

        empty_count = (labels == "").sum()
        if empty_count > 0:
            raise DatasetError(
                f"Column '{split_col}' contains {empty_count} empty string value(s). "
                f"All rows must have a valid split assignment from {allowed}."
            )

    @staticmethod
    def _sorted_for_message(values: set[object]) -> list[object]:
        """Return ``values`` in a deterministic order for use in an error message."""
        try:
            return sorted(values)
        except TypeError:
            # Mixed types (e.g. int and str) have no natural ordering; order by
            # repr so the caller still receives a DatasetError, not a TypeError.
            return sorted(values, key=repr)