# mae/data/splitter.py
# Uploaded by adelelsayed1991 via huggingface_hub (revision 5ffe2e2, verified).
# Standard library
import os
from pathlib import Path
# Data handling
import pandas as pd
import numpy as np
# Machine learning
from sklearn.model_selection import train_test_split
class CheXpertDataSplitter:
    """
    Advanced stratified train-validation-test splitter for the CheXpert dataset.

    Handles:
        - Patient-level splitting (prevents data leakage)
        - Multi-label stratification
        - Class imbalance awareness
        - Study-level grouping
    """

    # The 14 CheXpert observation labels, in the canonical dataset order.
    PATHOLOGIES = [
        'No Finding',
        'Enlarged Cardiomediastinum',
        'Cardiomegaly',
        'Lung Opacity',
        'Lung Lesion',
        'Edema',
        'Consolidation',
        'Pneumonia',
        'Atelectasis',
        'Pneumothorax',
        'Pleural Effusion',
        'Pleural Other',
        'Fracture',
        'Support Devices'
    ]

    def __init__(self, csv_path, val_size=0.15, test_size=0.15, random_state=42,
                 use_frontal_only=True, fill_uncertain='zeros', root=None):
        """
        Initialize the splitter.

        Args:
            csv_path: Path to train.csv from CheXpert-small.
            val_size: Validation set proportion (default: 0.15).
            test_size: Test set proportion (default: 0.15).
            random_state: Random seed for reproducibility.
            use_frontal_only: Use only frontal view images.
            fill_uncertain: How to handle uncertain (-1) labels
                ('zeros', 'ones', 'ignore').
            root: Optional dataset root directory (stored for caller
                compatibility; not used internally by this class).
        """
        self.csv_path = csv_path
        self.val_size = val_size
        self.test_size = test_size
        self.random_state = random_state
        self.use_frontal_only = use_frontal_only
        self.fill_uncertain = fill_uncertain
        self.root = root
        print("=" * 80)
        print("CheXpert Data Splitter - Preventing Data Leakage & Class Bias")
        print("=" * 80)

    def load_and_preprocess(self):
        """Load the CSV, optionally filter to frontal views, extract patient/study
        IDs from the image paths, and process uncertain labels.

        Returns:
            The preprocessed DataFrame (also stored on ``self.df``).
        """
        print("\n[1/5] Loading data...")
        self.df = pd.read_csv(self.csv_path)
        print(f"   Loaded {len(self.df)} images")
        # Filter for frontal views only
        if self.use_frontal_only:
            initial_count = len(self.df)
            self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)
            print(f"   Filtered to frontal views: {len(self.df)} images ({initial_count - len(self.df)} removed)")
        # Extract patient and study IDs from path.
        # Assumes paths of the form '<dataset>/<split>/patientXXXXX/studyN/...'
        # — TODO confirm against the actual CSV layout.
        print("\n[2/5] Extracting patient and study IDs...")
        self.df['patient_id'] = self.df['Path'].apply(lambda x: x.split('/')[2])
        self.df['study_id'] = self.df['Path'].apply(lambda x: x.split('/')[3])
        n_patients = self.df['patient_id'].nunique()
        # Bug fix: study ids ('study1', 'study2', ...) repeat across patients,
        # so count unique (patient, study) pairs rather than unique study labels.
        n_studies = self.df.groupby(['patient_id', 'study_id']).ngroups
        print(f"   Unique patients: {n_patients}")
        print(f"   Unique studies: {n_studies}")
        print(f"   Images per patient (avg): {len(self.df) / n_patients:.2f}")
        # Process uncertain labels
        print("\n[3/5] Processing uncertain labels...")
        self._process_uncertain_labels()
        return self.df

    def _process_uncertain_labels(self):
        """Process uncertain labels (-1) based on the chosen strategy, then
        fill missing (NaN) observations with 0 (treated as negative)."""
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                if self.fill_uncertain == 'zeros':
                    self.df[pathology] = self.df[pathology].replace(-1, 0)
                elif self.fill_uncertain == 'ones':
                    self.df[pathology] = self.df[pathology].replace(-1, 1)
                elif self.fill_uncertain == 'ignore':
                    pass  # Keep -1 as is
                # Unmentioned observations are blank in CheXpert; treat as 0.
                self.df[pathology] = self.df[pathology].fillna(0)
        print(f"   Uncertain labels strategy: {self.fill_uncertain}")

    def create_stratification_groups(self):
        """
        Create stratification groups based on multi-label combinations.
        Uses patient-level aggregation to prevent data leakage.

        Returns:
            DataFrame with one row per patient, including a
            'stratification_group' column suitable for stratified splitting.
        """
        print("\n[4/5] Creating stratification groups (patient-level)...")
        # Group by patient: a patient is positive for a pathology if ANY of
        # their images is positive ('max' over 0/1 labels).
        patient_groups = self.df.groupby('patient_id').agg({
            **{pathology: 'max' for pathology in self.PATHOLOGIES if pathology in self.df.columns},
            'study_id': 'first',  # Keep one study_id for reference
            'Sex': 'first',
            'Age': 'first'
        }).reset_index()

        # Create label signature for each patient: a string of digits
        # representing which conditions are present.
        def create_label_signature(row):
            signature = []
            for pathology in self.PATHOLOGIES:
                if pathology in patient_groups.columns:
                    signature.append(str(int(row[pathology])))
            return ''.join(signature)

        patient_groups['label_signature'] = patient_groups.apply(create_label_signature, axis=1)

        # For rare combinations, group them together so that every
        # stratification class has enough members for train_test_split.
        signature_counts = patient_groups['label_signature'].value_counts()
        rare_threshold = max(5, int(len(patient_groups) * 0.001))  # At least 5 or 0.1%

        def get_stratification_group(signature):
            if signature_counts[signature] < rare_threshold:
                return 'RARE_COMBINATION'
            return signature

        patient_groups['stratification_group'] = patient_groups['label_signature'].apply(get_stratification_group)

        # Print distribution statistics
        print(f"\n   Patient-level label distribution:")
        for pathology in self.PATHOLOGIES:
            if pathology in patient_groups.columns:
                positive_count = (patient_groups[pathology] == 1).sum()
                percentage = positive_count / len(patient_groups) * 100
                print(f"   {pathology:30s}: {positive_count:5d} ({percentage:5.2f}%)")
        unique_groups = patient_groups['stratification_group'].nunique()
        print(f"\n   Unique stratification groups: {unique_groups}")
        print(f"   Rare combinations grouped: {(patient_groups['stratification_group'] == 'RARE_COMBINATION').sum()}")
        return patient_groups

    def perform_split(self, patient_groups):
        """
        Perform stratified train-validation-test split at patient level.

        Args:
            patient_groups: Output of :meth:`create_stratification_groups`.

        Returns:
            (train_df, val_df, test_df) image-level DataFrames.

        Raises:
            ValueError: If any patient appears in more than one split.
        """
        print("\n[5/5] Performing stratified patient-level split...")
        stratification_labels = patient_groups['stratification_group'].values
        # ---- train / (val+test) ----
        train_patients, valtest_patients = train_test_split(
            patient_groups['patient_id'].values,
            test_size=self.val_size + self.test_size,
            stratify=stratification_labels,
            random_state=self.random_state
        )
        # ---- val / test from the remaining pool ----
        remaining_labels = patient_groups.set_index('patient_id').loc[valtest_patients]['stratification_group'].values
        val_patients, test_patients = train_test_split(
            valtest_patients,
            test_size=self.test_size / (self.val_size + self.test_size),  # proportion of the val+test pool
            stratify=remaining_labels,
            random_state=self.random_state
        )
        print(f"   Train patients: {len(train_patients)}")
        print(f"   Val patients: {len(val_patients)}")
        print(f"   Test patients: {len(test_patients)}")
        # Split the full dataframe
        train_df = self.df[self.df['patient_id'].isin(train_patients)].copy()
        val_df = self.df[self.df['patient_id'].isin(val_patients)].copy()
        test_df = self.df[self.df['patient_id'].isin(test_patients)].copy()
        # ---- leakage check (every pair of splits must share no patients) ----
        sets = [('train', train_df), ('val', val_df), ('test', test_df)]
        for i, (name_i, df_i) in enumerate(sets):
            for name_j, df_j in sets[i + 1:]:
                overlap = set(df_i['patient_id']).intersection(set(df_j['patient_id']))
                if overlap:
                    raise ValueError(f"Data leakage between {name_i} and {name_j}: {len(overlap)} patients overlap")
        print("\n   No patient overlap – leakage prevented!")
        return train_df, val_df, test_df

    def run(self, output_dir='.', save_test=True):
        """Execute the full pipeline: load, stratify, split, verify, and save.

        Args:
            output_dir: Directory where the ready CSVs are written.
            save_test: Whether to also write the test split to disk.

        Returns:
            (train_path, val_path, test_path); test_path is None when
            ``save_test`` is False.
        """
        self.load_and_preprocess()
        patient_groups = self.create_stratification_groups()
        train_df, val_df, test_df = self.perform_split(patient_groups)
        self.verify_split_quality(train_df, val_df)
        # Also verify train vs test (same function works with any two dfs).
        print("\n--- Train vs Test distribution check ---")
        self.verify_split_quality(train_df, test_df)
        train_path, val_path = self.save_splits(train_df, val_df, output_dir)
        if save_test:
            test_path = self.save_test_split(test_df, output_dir)
        else:
            test_path = None
        print("\n" + "=" * 80)
        print("Split Complete! (train / val / test)")
        print("=" * 80)
        return train_path, val_path, test_path

    def save_test_split(self, test_df, output_dir):
        """Write the test split to '<output_dir>/test_ready.csv' (without the
        temporary split columns) and return the path."""
        output_dir = Path(output_dir)
        # parents=True so nested output directories are created as needed.
        output_dir.mkdir(parents=True, exist_ok=True)
        test_path = output_dir / 'test_ready.csv'
        cols_to_drop = ['patient_id', 'study_id']
        test_clean = test_df.drop(columns=[c for c in cols_to_drop if c in test_df.columns])
        test_clean.to_csv(test_path, index=False)
        print(f"Test set : {test_path} ({len(test_clean)} images)")
        return test_path

    def verify_split_quality(self, train_df, val_df):
        """
        Verify the quality of the split by comparing label distributions
        between two splits, then report class imbalance on the first one.
        """
        print("\n" + "=" * 80)
        print("Split Quality Verification")
        print("=" * 80)
        print(f"\n{'Pathology':<30s} {'Train %':>10s} {'Val %':>10s} {'Difference':>12s}")
        print("-" * 80)
        max_diff = 0
        for pathology in self.PATHOLOGIES:
            if pathology in train_df.columns:
                train_pos = (train_df[pathology] == 1).sum() / len(train_df) * 100
                val_pos = (val_df[pathology] == 1).sum() / len(val_df) * 100
                diff = abs(train_pos - val_pos)
                max_diff = max(max_diff, diff)
                print(f"{pathology:<30s} {train_pos:>9.2f}% {val_pos:>9.2f}% {diff:>11.2f}%")
        print("-" * 80)
        print(f"Maximum distribution difference: {max_diff:.2f}%")
        if max_diff < 2.0:
            print("✓ Excellent stratification (< 2% difference)")
        elif max_diff < 5.0:
            print("✓ Good stratification (< 5% difference)")
        else:
            print("⚠ Warning: Large distribution differences detected")
        # Check for class imbalance
        print("\n" + "=" * 80)
        print("Class Imbalance Analysis (Train Set)")
        print("=" * 80)
        imbalance_ratios = []
        for pathology in self.PATHOLOGIES:
            if pathology in train_df.columns:
                pos = (train_df[pathology] == 1).sum()
                neg = (train_df[pathology] == 0).sum()
                if pos > 0:
                    ratio = neg / pos
                    imbalance_ratios.append(ratio)
                    severity = "Low" if ratio < 5 else "Medium" if ratio < 20 else "High"
                    print(f"{pathology:<30s} Ratio: {ratio:>6.2f}:1 [{severity:>6s} imbalance]")
        # Guard against np.mean([]) -> NaN + RuntimeWarning when no pathology
        # column has positive samples.
        if imbalance_ratios:
            avg_imbalance = np.mean(imbalance_ratios)
            print(f"\nAverage imbalance ratio: {avg_imbalance:.2f}:1")
        else:
            print("\nAverage imbalance ratio: n/a (no positive samples found)")

    def save_splits(self, train_df, val_df, output_dir='.'):
        """Save train and validation splits to CSV files.

        Returns:
            (train_path, val_path) as Path objects.
        """
        output_dir = Path(output_dir)
        # parents=True so nested output directories are created as needed.
        output_dir.mkdir(parents=True, exist_ok=True)
        train_path = output_dir / 'train_ready.csv'
        val_path = output_dir / 'val_ready.csv'
        # Remove temporary columns used for splitting
        columns_to_drop = ['patient_id', 'study_id']
        train_df_clean = train_df.drop(columns=[col for col in columns_to_drop if col in train_df.columns])
        val_df_clean = val_df.drop(columns=[col for col in columns_to_drop if col in val_df.columns])
        train_df_clean.to_csv(train_path, index=False)
        val_df_clean.to_csv(val_path, index=False)
        print("\n" + "=" * 80)
        print("Files Saved Successfully")
        print("=" * 80)
        print(f"Train set: {train_path} ({len(train_df_clean)} images)")
        print(f"Val set: {val_path} ({len(val_df_clean)} images)")
        return train_path, val_path
# Script entry point: configure the splitter, reuse cached splits if present,
# otherwise generate train/val/test CSVs.
if __name__ == "__main__":
    root = "/content/drive/MyDrive"

    # Configuration
    CHEXPERT_CSV = os.path.join(root, "CheXpert-v1.0-small", "train.csv")  # Adjust path as needed
    OUTPUT_DIR = os.path.join(root, "CheXpert-v1.0-small")
    VAL_SIZE = 0.15
    RANDOM_STATE = 42
    USE_FRONTAL_ONLY = True
    FILL_UNCERTAIN = 'zeros'  # Options: 'zeros', 'ones', 'ignore'

    # Create splitter (test split uses the same proportion as validation)
    splitter = CheXpertDataSplitter(
        csv_path=CHEXPERT_CSV,
        val_size=VAL_SIZE,
        test_size=VAL_SIZE,
        random_state=RANDOM_STATE,
        use_frontal_only=USE_FRONTAL_ONLY,
        fill_uncertain=FILL_UNCERTAIN,
        root=OUTPUT_DIR
    )

    # Skip the (expensive) split when train/val CSVs already exist on disk.
    ready = {
        name: os.path.join(root, "CheXpert-v1.0-small", f"{name}_ready.csv")
        for name in ("train", "val", "test")
    }
    if os.path.exists(ready["train"]) and os.path.exists(ready["val"]):
        train_path = ready["train"]
        val_path = ready["val"]
        test_path = ready["test"]
    else:
        train_path, val_path, test_path = splitter.run(output_dir=OUTPUT_DIR)

    print("\nYou can now use these files with your CheXpertDataset class:")
    print(f" train_dataset = CheXpertDataset('{train_path}', root_dir='...', augment=True)")
    print(f" val_dataset = CheXpertDataset('{val_path}', root_dir='...', augment=False)")
    print(f" test_dataset = CheXpertDataset('{test_path}', root_dir='...', augment=False)")