import os
from itertools import combinations
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


class CheXpertDataSplitter:
    """
    Advanced stratified train/validation/test splitter for the CheXpert dataset.

    Handles:
    - Patient-level splitting (prevents data leakage)
    - Multi-label stratification
    - Class imbalance awareness
    - Study-level grouping

    The usual entry point is :meth:`run`, which loads the CSV, builds
    patient-level stratification groups, splits the patients, prints a
    distribution report and writes ``train_ready.csv`` / ``val_ready.csv`` /
    ``test_ready.csv`` to the output directory.
    """

    PATHOLOGIES = [
        'No Finding',
        'Enlarged Cardiomediastinum',
        'Cardiomegaly',
        'Lung Opacity',
        'Lung Lesion',
        'Edema',
        'Consolidation',
        'Pneumonia',
        'Atelectasis',
        'Pneumothorax',
        'Pleural Effusion',
        'Pleural Other',
        'Fracture',
        'Support Devices'
    ]

    # Values accepted for ``fill_uncertain``.
    UNCERTAIN_STRATEGIES = ('zeros', 'ones', 'ignore')

    # Stratification label shared by every label combination that is too rare
    # to be stratified on its own.
    RARE_GROUP = 'RARE_COMBINATION'

    # Helper columns added by load_and_preprocess(); removed before saving.
    _HELPER_COLUMNS = ('patient_id', 'study_id')

    def __init__(self, csv_path, val_size=0.15, test_size=0.15, random_state=42,
                 use_frontal_only=True, fill_uncertain='zeros', root=None):
        """
        Initialize the splitter.

        Args:
            csv_path: Path to train.csv from CheXpert-small
            val_size: Validation set proportion (default: 0.15)
            test_size: Test set proportion (default: 0.15)
            random_state: Random seed for reproducibility
            use_frontal_only: Use only frontal view images
            fill_uncertain: How to handle uncertain labels ('zeros', 'ones', 'ignore')
            root: Stored on the instance as ``self.root``; not used by the
                splitter itself.

        Raises:
            ValueError: If ``fill_uncertain`` is not one of the supported
                strategies.
        """
        # Fail early: an unknown strategy would otherwise silently behave like
        # 'ignore' and leave -1 labels in the data.
        if fill_uncertain not in self.UNCERTAIN_STRATEGIES:
            raise ValueError(
                f"fill_uncertain must be one of {self.UNCERTAIN_STRATEGIES}, "
                f"got {fill_uncertain!r}"
            )

        self.csv_path = csv_path
        self.val_size = val_size
        self.test_size = test_size
        self.random_state = random_state
        self.use_frontal_only = use_frontal_only
        self.fill_uncertain = fill_uncertain
        self.root = root

        print("=" * 80)
        print("CheXpert Data Splitter - Preventing Data Leakage & Class Bias")
        print("=" * 80)

    def _label_columns(self, frame):
        """Return the PATHOLOGIES that exist as columns of ``frame``, in PATHOLOGIES order."""
        return [pathology for pathology in self.PATHOLOGIES if pathology in frame.columns]

    def load_and_preprocess(self):
        """
        Load the CSV, optionally keep frontal views only, and derive IDs.

        Adds ``patient_id`` and ``study_id`` columns taken from the third and
        fourth '/'-separated components of ``Path`` (e.g.
        ``CheXpert-v1.0-small/train/patient00001/study1/...``) and applies the
        uncertain-label strategy.

        Returns:
            The preprocessed DataFrame (also stored as ``self.df``).

        Raises:
            ValueError: If no rows remain after filtering, or a ``Path`` value
                has fewer than four '/'-separated components.
        """
        print("\n[1/5] Loading data...")
        self.df = pd.read_csv(self.csv_path)
        print(f" Loaded {len(self.df)} images")

        if self.use_frontal_only:
            initial_count = len(self.df)
            is_frontal = self.df['Frontal/Lateral'] == 'Frontal'
            self.df = self.df[is_frontal].reset_index(drop=True)
            print(f" Filtered to frontal views: {len(self.df)} images ({initial_count - len(self.df)} removed)")

        # Nothing to split; stop here instead of dividing by zero below.
        if self.df.empty:
            raise ValueError(f"No images left to split after loading {self.csv_path}")

        print("\n[2/5] Extracting patient and study IDs...")
        path_parts = self.df['Path'].str.split('/')
        self.df['patient_id'] = path_parts.str[2]
        self.df['study_id'] = path_parts.str[3]

        malformed = self.df['patient_id'].isna() | self.df['study_id'].isna()
        if malformed.any():
            raise ValueError(
                f"{int(malformed.sum())} 'Path' value(s) do not contain patient "
                "and study components"
            )

        n_patients = self.df['patient_id'].nunique()
        # Study names repeat across patients, so a study is identified by the
        # (patient, study) pair rather than by its name alone.
        n_studies = len(self.df[['patient_id', 'study_id']].drop_duplicates())
        print(f" Unique patients: {n_patients}")
        print(f" Unique studies: {n_studies}")
        print(f" Images per patient (avg): {len(self.df) / n_patients:.2f}")

        print("\n[3/5] Processing uncertain labels...")
        self._process_uncertain_labels()

        return self.df

    def _process_uncertain_labels(self):
        """
        Process uncertain labels (-1) based on the chosen strategy.

        'zeros' maps -1 to 0, 'ones' maps -1 to 1 and 'ignore' leaves -1
        untouched. Missing labels (NaN) are set to 0 in every case.
        """
        # None means "leave -1 as it is" (the 'ignore' strategy).
        replacement = {'zeros': 0, 'ones': 1}.get(self.fill_uncertain)

        uncertain_total = 0
        for pathology in self._label_columns(self.df):
            column = self.df[pathology]
            uncertain_total += int((column == -1).sum())
            if replacement is not None:
                column = column.replace(-1, replacement)
            self.df[pathology] = column.fillna(0)

        print(f" Uncertain labels found: {uncertain_total}")
        print(f" Uncertain labels strategy: {self.fill_uncertain}")

    def create_stratification_groups(self):
        """
        Create stratification groups based on multi-label combinations.

        Uses patient-level aggregation to prevent data leakage: every label is
        reduced to its maximum over the patient's images, the per-label values
        are concatenated into a ``label_signature`` string, and signatures seen
        for fewer than ``max(5, 0.1% of patients)`` patients are pooled into
        the ``RARE_COMBINATION`` group.

        Returns:
            DataFrame with one row per patient, including ``label_signature``
            and ``stratification_group`` columns.
        """
        print("\n[4/5] Creating stratification groups (patient-level)...")

        label_columns = self._label_columns(self.df)
        aggregations = {pathology: 'max' for pathology in label_columns}
        aggregations['study_id'] = 'first'
        # Demographics are carried along when available but are not required.
        for extra in ('Sex', 'Age'):
            if extra in self.df.columns:
                aggregations[extra] = 'first'

        patient_groups = self.df.groupby('patient_id').agg(aggregations).reset_index()

        # Build the signature column-wise instead of with a row-wise apply.
        if label_columns:
            label_text = patient_groups[label_columns].astype(int).astype(str)
            signature = label_text[label_columns[0]]
            for pathology in label_columns[1:]:
                signature = signature + label_text[pathology]
        else:
            signature = pd.Series('', index=patient_groups.index)
        patient_groups['label_signature'] = signature

        signature_counts = patient_groups['label_signature'].value_counts()
        rare_threshold = max(5, int(len(patient_groups) * 0.001))

        occurrences = patient_groups['label_signature'].map(signature_counts)
        patient_groups['stratification_group'] = patient_groups['label_signature'].where(
            occurrences >= rare_threshold, self.RARE_GROUP
        )

        print("\n Patient-level label distribution:")
        for pathology in label_columns:
            positive_count = (patient_groups[pathology] == 1).sum()
            percentage = positive_count / len(patient_groups) * 100
            print(f" {pathology:30s}: {positive_count:5d} ({percentage:5.2f}%)")

        unique_groups = patient_groups['stratification_group'].nunique()
        print(f"\n Unique stratification groups: {unique_groups}")
        print(f" Rare combinations grouped: {(patient_groups['stratification_group'] == self.RARE_GROUP).sum()}")

        return patient_groups

    def _stratifiable_labels(self, labels):
        """
        Return labels that are safe to pass as ``stratify=``, or None.

        A stratified split needs at least two members per class. Classes with a
        single member are folded into the rare bucket; if that bucket still has
        a single member it is merged into the largest remaining class. Labels
        are returned unchanged when every class already has two members, so
        well-populated data splits the same way as before.
        """
        labels = pd.Series(np.asarray(labels), dtype=object)
        counts = labels.value_counts()
        if (counts >= 2).all():
            return labels.values

        singletons = counts.index[counts < 2]
        labels = labels.where(~labels.isin(singletons), self.RARE_GROUP)
        counts = labels.value_counts()

        if counts.get(self.RARE_GROUP, 2) < 2:
            if len(counts) == 1:
                # A lone sample cannot be stratified at all.
                return None
            largest = counts.drop(self.RARE_GROUP).idxmax()
            labels = labels.where(labels != self.RARE_GROUP, largest)

        return labels.values

    def perform_split(self, patient_groups):
        """
        Perform stratified train-validation-test split at patient level.

        Patients are first split into train and a hold-out pool of size
        ``val_size + test_size``; the pool is then split into validation and
        test. One of ``val_size`` / ``test_size`` may be 0, in which case the
        corresponding split is empty.

        Args:
            patient_groups: Output of :meth:`create_stratification_groups`.

        Returns:
            Tuple ``(train_df, val_df, test_df)`` of image-level DataFrames.

        Raises:
            ValueError: If the split fractions are invalid, or if a patient
                ends up in more than one split.
        """
        print("\n[5/5] Performing stratified patient-level split...")

        holdout_size = self.val_size + self.test_size
        if self.val_size < 0 or self.test_size < 0 or not 0 < holdout_size < 1:
            raise ValueError(
                "val_size and test_size must be non-negative and sum to a value "
                f"strictly between 0 and 1 (got val_size={self.val_size}, "
                f"test_size={self.test_size})"
            )

        train_patients, valtest_patients = train_test_split(
            patient_groups['patient_id'].values,
            test_size=holdout_size,
            stratify=self._stratifiable_labels(patient_groups['stratification_group']),
            random_state=self.random_state
        )

        if self.test_size == 0:
            # Whole hold-out pool is validation data.
            val_patients, test_patients = valtest_patients, valtest_patients[:0]
        elif self.val_size == 0:
            # Whole hold-out pool is test data.
            val_patients, test_patients = valtest_patients[:0], valtest_patients
        else:
            remaining_labels = patient_groups.set_index('patient_id').loc[
                valtest_patients, 'stratification_group'
            ]
            val_patients, test_patients = train_test_split(
                valtest_patients,
                test_size=self.test_size / holdout_size,
                stratify=self._stratifiable_labels(remaining_labels),
                random_state=self.random_state
            )

        print(f" Train patients: {len(train_patients)}")
        print(f" Val patients: {len(val_patients)}")
        print(f" Test patients: {len(test_patients)}")

        train_df = self.df[self.df['patient_id'].isin(train_patients)].copy()
        val_df = self.df[self.df['patient_id'].isin(val_patients)].copy()
        test_df = self.df[self.df['patient_id'].isin(test_patients)].copy()

        # Safety net: no patient may appear in two splits.
        patients_by_split = [
            (name, set(frame['patient_id']))
            for name, frame in (('train', train_df), ('val', val_df), ('test', test_df))
        ]
        for (name_i, ids_i), (name_j, ids_j) in combinations(patients_by_split, 2):
            overlap = ids_i & ids_j
            if overlap:
                raise ValueError(f"Data leakage between {name_i} and {name_j}: {len(overlap)} patients overlap")
        print("\n No patient overlap – leakage prevented!")

        return train_df, val_df, test_df

    def run(self, output_dir='.', save_test=True):
        """
        Run the full pipeline: load, group, split, verify and save.

        Args:
            output_dir: Directory the CSV files are written to.
            save_test: Whether to write the test split as well.

        Returns:
            Tuple ``(train_path, val_path, test_path)``; ``test_path`` is None
            when the test split is not saved or is empty.
        """
        self.load_and_preprocess()
        patient_groups = self.create_stratification_groups()
        train_df, val_df, test_df = self.perform_split(patient_groups)

        self.verify_split_quality(train_df, val_df)

        print("\n--- Train vs Test distribution check ---")
        # The train-set imbalance report was already printed above.
        self.verify_split_quality(train_df, test_df, split_name='Test', report_imbalance=False)

        train_path, val_path = self.save_splits(train_df, val_df, output_dir)
        if save_test and not test_df.empty:
            test_path = self.save_test_split(test_df, output_dir)
        else:
            test_path = None

        print("\n" + "=" * 80)
        print("Split Complete! (train / val / test)")
        print("=" * 80)
        return train_path, val_path, test_path

    def _write_split(self, frame, path):
        """Write ``frame`` to ``path`` without the helper columns; return the written frame."""
        cleaned = frame.drop(columns=[col for col in self._HELPER_COLUMNS if col in frame.columns])
        cleaned.to_csv(path, index=False)
        return cleaned

    def save_test_split(self, test_df, output_dir):
        """
        Save the test split as ``test_ready.csv`` inside ``output_dir``.

        Returns:
            Path of the written file.
        """
        output_dir = Path(output_dir)
        # parents=True so a nested, not-yet-existing directory also works.
        output_dir.mkdir(parents=True, exist_ok=True)
        test_path = output_dir / 'test_ready.csv'

        test_clean = self._write_split(test_df, test_path)

        print(f"Test set : {test_path} ({len(test_clean)} images)")
        return test_path

    def verify_split_quality(self, train_df, val_df, split_name='Val', report_imbalance=True):
        """
        Verify the quality of the split by comparing label distributions.

        Prints, per pathology, the share of positive images in ``train_df`` and
        ``val_df`` and their absolute difference, followed (optionally) by the
        negative:positive ratio of each pathology in ``train_df``.

        Args:
            train_df: Training split.
            val_df: Split to compare against the training split.
            split_name: Column heading used for ``val_df`` in the report.
            report_imbalance: Whether to print the train-set imbalance section.
        """
        print("\n" + "=" * 80)
        print("Split Quality Verification")
        print("=" * 80)

        # Percentages are undefined for an empty split.
        if train_df.empty or val_df.empty:
            print("Skipping distribution check: one of the splits is empty")
            return

        label_columns = self._label_columns(train_df)

        print(f"\n{'Pathology':<30s} {'Train %':>10s} {split_name + ' %':>10s} {'Difference':>12s}")
        print("-" * 80)

        max_diff = 0
        for pathology in label_columns:
            train_pos = (train_df[pathology] == 1).sum() / len(train_df) * 100
            val_pos = (val_df[pathology] == 1).sum() / len(val_df) * 100
            diff = abs(train_pos - val_pos)
            max_diff = max(max_diff, diff)

            print(f"{pathology:<30s} {train_pos:>9.2f}% {val_pos:>9.2f}% {diff:>11.2f}%")

        print("-" * 80)
        print(f"Maximum distribution difference: {max_diff:.2f}%")

        if max_diff < 2.0:
            print("✓ Excellent stratification (< 2% difference)")
        elif max_diff < 5.0:
            print("✓ Good stratification (< 5% difference)")
        else:
            print("⚠ Warning: Large distribution differences detected")

        if not report_imbalance:
            return

        print("\n" + "=" * 80)
        print("Class Imbalance Analysis (Train Set)")
        print("=" * 80)

        imbalance_ratios = []
        for pathology in label_columns:
            pos = (train_df[pathology] == 1).sum()
            neg = (train_df[pathology] == 0).sum()
            if pos == 0:
                # No ratio can be computed without positives.
                print(f"{pathology:<30s} no positive samples")
                continue
            ratio = neg / pos
            imbalance_ratios.append(ratio)
            severity = "Low" if ratio < 5 else "Medium" if ratio < 20 else "High"
            print(f"{pathology:<30s} Ratio: {ratio:>6.2f}:1 [{severity:>6s} imbalance]")

        if imbalance_ratios:
            avg_imbalance = np.mean(imbalance_ratios)
            print(f"\nAverage imbalance ratio: {avg_imbalance:.2f}:1")

    def save_splits(self, train_df, val_df, output_dir='.'):
        """
        Save train and validation splits to CSV files.

        Writes ``train_ready.csv`` and ``val_ready.csv`` inside ``output_dir``
        without the ``patient_id`` / ``study_id`` helper columns.

        Returns:
            Tuple ``(train_path, val_path)`` of the written files.
        """
        output_dir = Path(output_dir)
        # parents=True so a nested, not-yet-existing directory also works.
        output_dir.mkdir(parents=True, exist_ok=True)

        train_path = output_dir / 'train_ready.csv'
        val_path = output_dir / 'val_ready.csv'

        train_df_clean = self._write_split(train_df, train_path)
        val_df_clean = self._write_split(val_df, val_path)

        print("\n" + "=" * 80)
        print("Files Saved Successfully")
        print("=" * 80)
        print(f"Train set: {train_path} ({len(train_df_clean)} images)")
        print(f"Val set: {val_path} ({len(val_df_clean)} images)")

        return train_path, val_path


if __name__ == "__main__":
    # Script entry point: split CheXpert-small once and reuse the saved CSVs
    # on later runs.
    root = "/content/drive/MyDrive"

    CHEXPERT_CSV = os.path.join(root, "CheXpert-v1.0-small", "train.csv")
    OUTPUT_DIR = os.path.join(root, "CheXpert-v1.0-small")
    VAL_SIZE = 0.15
    TEST_SIZE = 0.15
    RANDOM_STATE = 42
    USE_FRONTAL_ONLY = True
    FILL_UNCERTAIN = 'zeros'

    splitter = CheXpertDataSplitter(
        csv_path=CHEXPERT_CSV,
        val_size=VAL_SIZE,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        use_frontal_only=USE_FRONTAL_ONLY,
        fill_uncertain=FILL_UNCERTAIN,
        root=OUTPUT_DIR
    )

    ready_paths = [
        os.path.join(OUTPUT_DIR, file_name)
        for file_name in ("train_ready.csv", "val_ready.csv", "test_ready.csv")
    ]

    # Reuse a previous split only when all three files exist; otherwise the
    # reported test path could point at a file that was never written.
    if all(os.path.exists(path) for path in ready_paths):
        train_path, val_path, test_path = ready_paths
    else:
        train_path, val_path, test_path = splitter.run(output_dir=OUTPUT_DIR)

    print("\nYou can now use these files with your CheXpertDataset class:")
    print(f" train_dataset = CheXpertDataset('{train_path}', root_dir='...', augment=True)")
    print(f" val_dataset = CheXpertDataset('{val_path}', root_dir='...', augment=False)")
    print(f" test_dataset = CheXpertDataset('{test_path}', root_dir='...', augment=False)")