Martinacap02's picture
Init deploy branch for HF Space
f7d11f7
import argparse
from pathlib import Path
from loguru import logger
import pandas as pd
from predicting_outcomes_in_heart_failure.config import (
FEMALE_CSV,
MALE_CSV,
NOSEX_CSV,
PREPROCESSED_CSV,
PROCESSED_DATA_DIR,
RANDOM_STATE,
TARGET_COL,
TEST_SIZE,
)
from sklearn.model_selection import train_test_split
# Maps each variant name accepted on the command line to the CSV file that is
# read as input for that variant. The keys double as the allowed values of
# the --variant argument and as the output sub-directory names.
VARIANTS = {
    "all": PREPROCESSED_CSV,
    "female": FEMALE_CSV,
    "male": MALE_CSV,
    "nosex": NOSEX_CSV,
}
def _safe_train_test_split(X, y, test_size, random_state):
    """Split features and target into train/test parts, stratifying when possible.

    Stratification on ``y`` is attempted only when ``y`` holds more than one
    distinct value. If that stratified attempt raises ``ValueError`` (for
    example because a class has too few members to be divided), the split is
    retried once without stratification. A non-stratified split is never
    retried: if it fails, the error propagates to the caller unchanged.

    Args:
        X: Feature table, forwarded as-is to ``train_test_split``.
        y: Target values; must expose ``nunique()`` (e.g. a pandas Series).
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        random_state: Forwarded to ``train_test_split`` as ``random_state``.

    Returns:
        A 4-tuple ``(X_train, X_test, y_train, y_test)``.

    Raises:
        ValueError: If the non-stratified split is rejected by
            ``train_test_split``.
    """

    def _split(stratify):
        # Single place that fixes the arguments shared by every attempt, so
        # the stratified and fallback calls cannot drift apart.
        return tuple(
            train_test_split(
                X,
                y,
                test_size=test_size,
                stratify=stratify,
                random_state=random_state,
                shuffle=True,
            )
        )

    if y.nunique() <= 1:
        # Nothing to stratify on. Any error raised here is a genuine problem
        # with the inputs, so it is not caught and not retried.
        logger.warning("Target has only one class — performing non-stratified split.")
        return _split(None)

    try:
        parts = _split(y)
    except ValueError as e:
        logger.warning(f"Stratified split failed ({e}). Falling back to non-stratified split.")
        return _split(None)

    logger.debug("Stratified split executed successfully.")
    return parts
def split_one(csv_path: Path, variant: str) -> bool:
    """Split one dataset variant (all/female/male/nosex) into train/test CSVs.

    Reads ``csv_path``, separates ``TARGET_COL`` (cast to ``int``) from the
    remaining columns, splits them with ``_safe_train_test_split`` using
    ``TEST_SIZE`` and ``RANDOM_STATE``, and writes ``train.csv`` and
    ``test.csv`` under ``PROCESSED_DATA_DIR / variant``. In both output files
    the target column is placed last.

    Args:
        csv_path: Input CSV for this variant.
        variant: Variant name; used as the output sub-directory name and as
            the prefix of every log message.

    Returns:
        ``True`` when both output files were written, ``False`` when
        ``csv_path`` does not exist and the variant was skipped.

    Raises:
        ValueError: If ``TARGET_COL`` is not a column of the loaded CSV.
    """
    if not csv_path.exists():
        logger.warning(f"[{variant}] Missing CSV file: {csv_path} — skipping.")
        # Report the skip to the caller so it does not announce a success.
        return False

    df = pd.read_csv(csv_path)
    logger.info(f"[{variant}] Loaded {csv_path} (rows={len(df)}, cols={df.shape[1]})")

    if TARGET_COL not in df.columns:
        raise ValueError(f"[{variant}] Target column '{TARGET_COL}' not found in {csv_path}")

    X = df.drop(columns=[TARGET_COL])
    y = df[TARGET_COL].astype(int)
    X_train, X_test, y_train, y_test = _safe_train_test_split(X, y, TEST_SIZE, RANDOM_STATE)

    # Re-attach the target by position (.values) so it lines up with the
    # feature rows regardless of index labels, and ends up as the last column.
    train_df = X_train.copy()
    train_df[TARGET_COL] = y_train.values
    test_df = X_test.copy()
    test_df[TARGET_COL] = y_test.values

    out_dir = PROCESSED_DATA_DIR / variant
    out_dir.mkdir(parents=True, exist_ok=True)
    train_p = out_dir / "train.csv"
    test_p = out_dir / "test.csv"

    # Write both files before logging either, so a failed write is never
    # preceded by a "Saved" message.
    train_df.to_csv(train_p, index=False)
    test_df.to_csv(test_p, index=False)
    logger.success(f"[{variant}] Saved TRAIN -> {train_p} (rows={len(train_df)})")
    logger.success(f"[{variant}] Saved TEST -> {test_p} (rows={len(test_df)})")

    train_counts = train_df[TARGET_COL].value_counts().to_dict()
    test_counts = test_df[TARGET_COL].value_counts().to_dict()
    logger.info(f"[{variant}] Class distribution — TRAIN: {train_counts} | TEST: {test_counts}")
    return True


def main():
    """Parse ``--variant`` from the command line and split that variant.

    Creates ``PROCESSED_DATA_DIR`` if needed, then delegates to ``split_one``.
    Completion is logged at success level only when output was actually
    written; a skipped variant (missing input CSV) is logged as a warning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--variant",
        type=str,
        choices=list(VARIANTS),
        required=True,
        help="Data variant to split: all, female, male, or nosex.",
    )
    args = parser.parse_args()

    variant = args.variant
    csv_path = VARIANTS[variant]

    logger.info(f"Starting splitting for variant='{variant}'")
    PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)

    if split_one(csv_path, variant):
        logger.success(f"Splitting completed for variant='{variant}'")
    else:
        logger.warning(f"Splitting skipped for variant='{variant}' — no output written.")


if __name__ == "__main__":
    main()