File size: 3,502 Bytes
f7d11f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
from pathlib import Path

from loguru import logger
import pandas as pd
from predicting_outcomes_in_heart_failure.config import (
    FEMALE_CSV,
    MALE_CSV,
    NOSEX_CSV,
    PREPROCESSED_CSV,
    PROCESSED_DATA_DIR,
    RANDOM_STATE,
    TARGET_COL,
    TEST_SIZE,
)
from sklearn.model_selection import train_test_split

# Maps each accepted ``--variant`` CLI choice to the CSV file it is split from.
VARIANTS = dict(
    all=PREPROCESSED_CSV,
    female=FEMALE_CSV,
    male=MALE_CSV,
    nosex=NOSEX_CSV,
)


def _safe_train_test_split(X, y, test_size, random_state):
    """Split ``X``/``y`` into train/test sets, stratifying on ``y`` when possible.

    Stratification is skipped up front when ``y`` has only one distinct value.
    Otherwise a stratified split is attempted. If it raises ``ValueError``
    (e.g. a class has too few members for the requested ``test_size``), the
    split is retried once without stratification.

    Args:
        X: Feature data, forwarded to ``train_test_split``.
        y: Target series; must support ``nunique()``.
        test_size: Forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``.

    Raises:
        ValueError: If the non-stratified split itself fails. That split is
            not retried, because a retry would repeat the identical call.
    """
    split_kwargs = dict(test_size=test_size, random_state=random_state, shuffle=True)

    if y.nunique() <= 1:
        # A single class cannot be stratified. Keep this outside the retry
        # path: a failure here would otherwise be reported as a "stratified"
        # failure and re-run with exactly the same arguments.
        logger.warning("Target has only one class — performing non-stratified split.")
        X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=None, **split_kwargs)
        return X_tr, X_te, y_tr, y_te

    try:
        # Only the split call is guarded, so errors raised elsewhere in this
        # function are never mistaken for a stratification failure.
        X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, **split_kwargs)
    except ValueError as e:
        logger.warning(f"Stratified split failed ({e}). Falling back to non-stratified split.")
        X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=None, **split_kwargs)
        return X_tr, X_te, y_tr, y_te

    logger.debug("Stratified split executed successfully.")
    return X_tr, X_te, y_tr, y_te


def split_one(csv_path: Path, variant: str):
    """Split one data variant into train/test CSV files.

    Reads ``csv_path`` and separates ``TARGET_COL`` from the feature columns.
    The data is split with ``TEST_SIZE`` and ``RANDOM_STATE`` via
    ``_safe_train_test_split``. The results are written to ``train.csv`` and
    ``test.csv`` under ``PROCESSED_DATA_DIR / variant``, with the target
    appended as the last column of each file.

    Args:
        csv_path: Input CSV for this variant. If it does not exist, a warning
            is logged and the function returns without writing anything.
        variant: Variant name (e.g. ``"all"``, ``"female"``). Used as the log
            prefix and as the output sub-directory name.

    Raises:
        ValueError: If ``TARGET_COL`` is absent from the CSV or contains
            missing values.
    """
    if not csv_path.exists():
        logger.warning(f"[{variant}] Missing CSV file: {csv_path} — skipping.")
        return

    df = pd.read_csv(csv_path)
    logger.info(f"[{variant}] Loaded {csv_path} (rows={len(df)}, cols={df.shape[1]})")

    if TARGET_COL not in df.columns:
        raise ValueError(f"[{variant}] Target column '{TARGET_COL}' not found in {csv_path}")

    # ``astype(int)`` on a column containing NaN fails with a generic casting
    # error. Report it explicitly so the offending variant and file are named.
    n_missing = int(df[TARGET_COL].isna().sum())
    if n_missing:
        raise ValueError(
            f"[{variant}] Target column '{TARGET_COL}' has {n_missing} missing "
            f"value(s) in {csv_path}"
        )

    X = df.drop(columns=[TARGET_COL])
    y = df[TARGET_COL].astype(int)

    X_train, X_test, y_train, y_test = _safe_train_test_split(X, y, TEST_SIZE, RANDOM_STATE)

    # Re-attach the target positionally (``.values``). The X and y splits
    # share the same row order, so the features and labels stay aligned.
    train_df = X_train.copy()
    train_df[TARGET_COL] = y_train.values
    test_df = X_test.copy()
    test_df[TARGET_COL] = y_test.values

    out_dir = PROCESSED_DATA_DIR / variant
    out_dir.mkdir(parents=True, exist_ok=True)
    train_p = out_dir / "train.csv"
    test_p = out_dir / "test.csv"

    train_df.to_csv(train_p, index=False)
    test_df.to_csv(test_p, index=False)

    logger.success(f"[{variant}] Saved TRAIN -> {train_p} (rows={len(train_df)})")
    logger.success(f"[{variant}] Saved TEST  -> {test_p} (rows={len(test_df)})")

    train_counts = train_df[TARGET_COL].value_counts().to_dict()
    test_counts = test_df[TARGET_COL].value_counts().to_dict()
    logger.info(f"[{variant}] Class distribution — TRAIN: {train_counts} | TEST: {test_counts}")


def main():
    """CLI entry point: split the data variant selected with ``--variant``.

    The chosen variant name is looked up in ``VARIANTS`` to find its source
    CSV. The processed-data directory is created before ``split_one`` runs.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--variant",
        type=str,
        choices=list(VARIANTS),
        required=True,
        help="Data variant to split: all, female, male, or nosex.",
    )
    chosen = cli.parse_args().variant
    source_csv = VARIANTS[chosen]

    logger.info(f"Starting splitting for variant='{chosen}'")
    PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)
    split_one(source_csv, chosen)
    logger.success(f"Splitting completed for variant='{chosen}'")


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()