MuleGuard / src /data /load.py
MuleGuard
MuleGuard: end-to-end mule-account detection + HF Space deploy
af879c2
Raw
History Blame Contribute Delete
1.34 kB
"""Data loading with a fast parquet cache and a leakage-safe stratified split."""
from __future__ import annotations

import contextlib
import os
import warnings
from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split

from src import config
def load_raw(
    use_cache: bool = True,
    *,
    raw_csv: str | os.PathLike | None = None,
    cache_path: str | os.PathLike | None = None,
) -> pd.DataFrame:
    """Load the full dataset, using a parquet cache to speed up repeat loads.

    The cache is read only when it is at least as new as the CSV (or when there
    is no local CSV to compare against), so editing the CSV no longer serves
    stale data. After parsing the CSV the cache is (re)written; failing to
    write it only emits a warning, because the parsed frame is still valid.

    A leading unnamed column (pandas labels blank header cells "Unnamed: <i>")
    is treated as a saved row index and dropped.

    Args:
        use_cache: If False, always parse the CSV (the cache is still refreshed).
        raw_csv: CSV location; defaults to ``config.RAW_CSV``.
        cache_path: Parquet cache location; defaults to ``config.CACHE_PARQUET``.

    Returns:
        The dataset as a DataFrame.
    """
    # Passed to read_csv unchanged (not coerced to Path) so non-path sources keep working.
    raw = config.RAW_CSV if raw_csv is None else raw_csv
    cache = Path(config.CACHE_PARQUET if cache_path is None else cache_path)

    if use_cache and _cache_is_fresh(cache, raw):
        return pd.read_parquet(cache)

    df = pd.read_csv(raw)
    # The first column is an unnamed row index, not a feature.
    if len(df.columns):
        first = df.columns[0]
        if isinstance(first, str) and (first == "" or first.startswith("Unnamed")):
            df = df.drop(columns=[first])
    _write_cache(df, cache)
    return df


def _cache_is_fresh(cache: Path, raw) -> bool:
    """Return True if *cache* exists and is not older than the CSV at *raw*."""
    if not cache.exists():
        return False
    try:
        raw_mtime = os.stat(raw).st_mtime
    except (OSError, TypeError):
        # No local CSV file to compare against (missing, or not a filesystem
        # path), so the cache is the only copy worth trusting.
        return True
    return cache.stat().st_mtime >= raw_mtime


def _write_cache(df: pd.DataFrame, cache: Path) -> None:
    """Write *df* to *cache* via a temp file; warn instead of raising on failure."""
    tmp = cache.with_name(cache.name + ".tmp")
    try:
        cache.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(tmp)
        # Write-then-rename: an interrupted write never leaves a truncated
        # file at `cache` that later loads would read.
        tmp.replace(cache)
    except (ImportError, OSError, TypeError, ValueError) as err:
        # ImportError: no parquet engine installed; OSError: unwritable
        # location; TypeError/ValueError: columns the engine cannot serialise
        # (e.g. pyarrow's ArrowTypeError / ArrowInvalid).
        warnings.warn(
            f"Skipping parquet cache {cache}: {err}", RuntimeWarning, stacklevel=3
        )
        with contextlib.suppress(OSError):
            tmp.unlink()
def split_xy(df: pd.DataFrame):
    """Separate the label column from the feature columns.

    Returns:
        ``(X, y)``: ``X`` is *df* without the ``config.TARGET`` column, and
        ``y`` is that column cast to ``int``.
    """
    target_col = config.TARGET
    labels = df[target_col].astype(int)
    features = df.drop(columns=[target_col])
    return features, labels
def train_test(df: pd.DataFrame):
    """Split *df* into stratified train and holdout partitions.

    The holdout is a true holdout: it is meant only for the final evaluation
    and for feeding the live simulator.

    Returns:
        ``(X_tr, X_te, y_tr, y_te)``, in the order produced by
        ``sklearn.model_selection.train_test_split``.
    """
    features, labels = split_xy(df)
    split_kwargs = dict(
        test_size=config.TEST_SIZE,
        stratify=labels,  # keep the class ratio identical in both partitions
        random_state=config.SEED,
    )
    X_train, X_hold, y_train, y_hold = train_test_split(features, labels, **split_kwargs)
    return X_train, X_hold, y_train, y_hold