# Standard library
import os
from itertools import combinations
from pathlib import Path

# Data handling
import pandas as pd
import numpy as np

# Machine learning
from sklearn.model_selection import train_test_split


class CheXpertDataSplitter:
    """
    Stratified train / validation / test splitter for the CheXpert dataset.

    Handles:
    - Patient-level splitting (prevents data leakage)
    - Multi-label stratification
    - Class imbalance awareness
    - Study-level grouping
    """

    PATHOLOGIES = [
        'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
        'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
        'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other',
        'Fracture', 'Support Devices',
    ]

    # Stratification label shared by all patients whose label combination is
    # too rare to be stratified on by itself.
    RARE_GROUP = 'RARE_COMBINATION'

    # Accepted values for ``fill_uncertain``.
    UNCERTAIN_STRATEGIES = ('zeros', 'ones', 'ignore')

    # Columns added during splitting that are removed before saving.
    HELPER_COLUMNS = ('patient_id', 'study_id')

    def __init__(self, csv_path, val_size=0.15, test_size=0.15, random_state=42,
                 use_frontal_only=True, fill_uncertain='zeros', root=None):
        """
        Initialize the splitter.

        Args:
            csv_path: Path to train.csv from CheXpert-small
            val_size: Validation set proportion (default: 0.15)
            test_size: Test set proportion (default: 0.15)
            random_state: Random seed for reproducibility
            use_frontal_only: Use only frontal view images
            fill_uncertain: How to handle uncertain labels ('zeros', 'ones', 'ignore')
            root: Optional dataset root directory. It is stored as ``self.root``
                but not used by the splitting logic in this class.

        Raises:
            ValueError: If ``fill_uncertain`` is not a known strategy, or if the
                split proportions are not positive with ``val_size + test_size < 1``.
        """
        if fill_uncertain not in self.UNCERTAIN_STRATEGIES:
            raise ValueError(
                f"fill_uncertain must be one of {self.UNCERTAIN_STRATEGIES}, "
                f"got {fill_uncertain!r}"
            )
        if val_size <= 0 or test_size <= 0 or val_size + test_size >= 1:
            raise ValueError(
                "val_size and test_size must be positive and sum to less than 1, "
                f"got val_size={val_size}, test_size={test_size}"
            )

        self.csv_path = csv_path
        self.val_size = val_size
        self.test_size = test_size
        self.random_state = random_state
        self.use_frontal_only = use_frontal_only
        self.fill_uncertain = fill_uncertain
        self.root = root

        print("=" * 80)
        print("CheXpert Data Splitter - Preventing Data Leakage & Class Bias")
        print("=" * 80)

    def _label_columns(self, columns):
        """Return the known pathologies present in ``columns``, in PATHOLOGIES order."""
        return [pathology for pathology in self.PATHOLOGIES if pathology in columns]

    def load_and_preprocess(self):
        """
        Load the CSV, optionally keep frontal views, derive IDs and clean labels.

        Returns:
            The preprocessed DataFrame (also stored as ``self.df``).
        """
        print("\n[1/5] Loading data...")
        self.df = pd.read_csv(self.csv_path)
        print(f" Loaded {len(self.df)} images")

        # Filter for frontal views only
        if self.use_frontal_only:
            initial_count = len(self.df)
            self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)
            print(f" Filtered to frontal views: {len(self.df)} images "
                  f"({initial_count - len(self.df)} removed)")

        # Extract patient and study IDs from path, e.g.
        # "CheXpert-v1.0-small/train/patient00001/study1/view1_frontal.jpg"
        print("\n[2/5] Extracting patient and study IDs...")
        self.df['patient_id'] = self.df['Path'].apply(lambda p: p.split('/')[2])
        study_folder = self.df['Path'].apply(lambda p: p.split('/')[3])
        # Study folders ("study1", "study2", ...) restart for every patient, so a
        # study is only unique when qualified by its patient.
        self.df['study_id'] = self.df['patient_id'] + '/' + study_folder

        n_patients = self.df['patient_id'].nunique()
        n_studies = self.df['study_id'].nunique()
        print(f" Unique patients: {n_patients}")
        print(f" Unique studies: {n_studies}")
        print(f" Images per patient (avg): {len(self.df) / n_patients:.2f}")

        # Process uncertain labels
        print("\n[3/5] Processing uncertain labels...")
        self._process_uncertain_labels()

        return self.df

    def _process_uncertain_labels(self):
        """Process uncertain labels (-1) based on the chosen strategy; NaN becomes 0."""
        # 'ignore' has no replacement value and keeps -1 as is.
        replacement = {'zeros': 0, 'ones': 1}.get(self.fill_uncertain)
        for pathology in self._label_columns(self.df.columns):
            if replacement is not None:
                self.df[pathology] = self.df[pathology].replace(-1, replacement)
            # Missing (NaN) labels are treated as negative under every strategy.
            self.df[pathology] = self.df[pathology].fillna(0)
        print(f" Uncertain labels strategy: {self.fill_uncertain}")

    def create_stratification_groups(self):
        """
        Create stratification groups based on multi-label combinations.

        Uses patient-level aggregation to prevent data leakage: a patient is
        positive for a pathology if any of their images is (``max``).

        Returns:
            One row per patient with the aggregated labels, a ``label_signature``
            string and the ``stratification_group`` used for splitting.
        """
        print("\n[4/5] Creating stratification groups (patient-level)...")

        label_cols = self._label_columns(self.df.columns)

        # Group by patient and aggregate labels
        agg_spec = {pathology: 'max' for pathology in label_cols}
        agg_spec['study_id'] = 'first'  # Keep one study_id for reference
        for meta in ('Sex', 'Age'):
            if meta in self.df.columns:
                agg_spec[meta] = 'first'
        patient_groups = self.df.groupby('patient_id').agg(agg_spec).reset_index()

        # Label signature: one digit per pathology, e.g. "0100...". With the
        # 'ignore' strategy a still-uncertain label contributes "-1".
        def create_label_signature(row):
            return ''.join(str(int(row[pathology])) for pathology in label_cols)

        patient_groups['label_signature'] = patient_groups.apply(create_label_signature, axis=1)

        # For rare combinations, group them together
        signature_counts = patient_groups['label_signature'].value_counts()
        rare_threshold = max(5, int(len(patient_groups) * 0.001))  # At least 5 or 0.1%
        is_rare = patient_groups['label_signature'].map(signature_counts) < rare_threshold
        patient_groups['stratification_group'] = patient_groups['label_signature'].mask(
            is_rare, self.RARE_GROUP
        )

        # Print distribution statistics
        print("\n Patient-level label distribution:")
        for pathology in label_cols:
            positive_count = (patient_groups[pathology] == 1).sum()
            percentage = positive_count / len(patient_groups) * 100
            print(f" {pathology:30s}: {positive_count:5d} ({percentage:5.2f}%)")

        unique_groups = patient_groups['stratification_group'].nunique()
        print(f"\n Unique stratification groups: {unique_groups}")
        print(f" Rare combinations grouped: "
              f"{(patient_groups['stratification_group'] == self.RARE_GROUP).sum()}")

        return patient_groups

    def _stratification_labels(self, labels):
        """
        Make ``labels`` safe to pass as ``stratify`` to ``train_test_split``.

        scikit-learn rejects any class with fewer than 2 members. Such labels
        are folded into RARE_GROUP. If RARE_GROUP itself ends up with a single
        member, it is folded into the most frequent label. Labels that are
        already valid are returned unchanged.
        """
        labels = pd.Series(labels)
        counts = labels.value_counts()
        singletons = counts.index[counts < 2]
        if len(singletons) == 0:
            return labels.to_numpy()

        labels = labels.mask(labels.isin(singletons), self.RARE_GROUP)
        counts = labels.value_counts()
        if counts.get(self.RARE_GROUP, 0) < 2 and len(counts) > 1:
            most_common = counts.drop(self.RARE_GROUP).idxmax()
            labels = labels.mask(labels == self.RARE_GROUP, most_common)
        return labels.to_numpy()

    def perform_split(self, patient_groups):
        """
        Perform stratified train-validation-test split at patient level.

        First splits off a val+test pool of ``val_size + test_size`` patients,
        then splits that pool into val and test, stratifying both times.

        Returns:
            (train_df, val_df, test_df): image-level DataFrames.

        Raises:
            ValueError: If any patient ends up in more than one split.
        """
        print("\n[5/5] Performing stratified patient-level split...")

        holdout_size = self.val_size + self.test_size

        # ---- train / (val+test) ----
        train_patients, valtest_patients = train_test_split(
            patient_groups['patient_id'].values,
            test_size=holdout_size,
            stratify=self._stratification_labels(patient_groups['stratification_group'].values),
            random_state=self.random_state,
        )

        # ---- val / test from the remaining pool ----
        # A group can shrink to a single member inside the pool, so the labels are
        # re-checked for stratification before the second split.
        pool_labels = (
            patient_groups.set_index('patient_id')
            .loc[valtest_patients, 'stratification_group']
            .values
        )
        val_patients, test_patients = train_test_split(
            valtest_patients,
            test_size=self.test_size / holdout_size,  # proportion of the val+test pool
            stratify=self._stratification_labels(pool_labels),
            random_state=self.random_state,
        )

        print(f" Train patients: {len(train_patients)}")
        print(f" Val patients: {len(val_patients)}")
        print(f" Test patients: {len(test_patients)}")

        # Split the full (image-level) dataframe
        train_df = self.df[self.df['patient_id'].isin(train_patients)].copy()
        val_df = self.df[self.df['patient_id'].isin(val_patients)].copy()
        test_df = self.df[self.df['patient_id'].isin(test_patients)].copy()

        # ---- leakage check (train vs val vs test) ----
        patient_sets = [
            ('train', set(train_df['patient_id'])),
            ('val', set(val_df['patient_id'])),
            ('test', set(test_df['patient_id'])),
        ]
        for (name_i, ids_i), (name_j, ids_j) in combinations(patient_sets, 2):
            overlap = ids_i & ids_j
            if overlap:
                raise ValueError(
                    f"Data leakage between {name_i} and {name_j}: {len(overlap)} patients overlap"
                )

        print("\n No patient overlap – leakage prevented!")
        return train_df, val_df, test_df

    def run(self, output_dir='.', save_test=True):
        """
        Run the full pipeline: load, group, split, verify and save.

        Args:
            output_dir: Directory the CSV splits are written to (created if missing).
            save_test: Whether to also write ``test_ready.csv``.

        Returns:
            (train_path, val_path, test_path); ``test_path`` is None when
            ``save_test`` is False.
        """
        self.load_and_preprocess()
        patient_groups = self.create_stratification_groups()
        train_df, val_df, test_df = self.perform_split(patient_groups)

        self.verify_split_quality(train_df, val_df)
        print("\n--- Train vs Test distribution check ---")
        self.verify_split_quality(train_df, test_df, other_name='Test')

        train_path, val_path = self.save_splits(train_df, val_df, output_dir)
        test_path = self.save_test_split(test_df, output_dir) if save_test else None

        print("\n" + "=" * 80)
        print("Split Complete! (train / val / test)")
        print("=" * 80)

        return train_path, val_path, test_path

    def _drop_helper_columns(self, df):
        """Return ``df`` without the temporary columns added for splitting."""
        return df.drop(columns=[col for col in self.HELPER_COLUMNS if col in df.columns])

    def save_test_split(self, test_df, output_dir):
        """Write the test split to ``output_dir/test_ready.csv`` and return its path."""
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        test_path = output_dir / 'test_ready.csv'

        test_clean = self._drop_helper_columns(test_df)
        test_clean.to_csv(test_path, index=False)
        print(f"Test set : {test_path} ({len(test_clean)} images)")
        return test_path

    def verify_split_quality(self, train_df, val_df, other_name='Val'):
        """
        Compare per-pathology positive rates between two splits and report the
        class imbalance of ``train_df``.

        Args:
            train_df: Reference split (its imbalance is reported).
            val_df: Split compared against ``train_df``.
            other_name: Column header used for ``val_df`` (default 'Val').
        """
        print("\n" + "=" * 80)
        print("Split Quality Verification")
        print("=" * 80)

        other_header = f"{other_name} %"
        print(f"\n{'Pathology':<30s} {'Train %':>10s} {other_header:>10s} {'Difference':>12s}")
        print("-" * 80)

        max_diff = 0
        for pathology in self._label_columns(train_df.columns):
            train_pos = (train_df[pathology] == 1).sum() / len(train_df) * 100
            val_pos = (val_df[pathology] == 1).sum() / len(val_df) * 100
            diff = abs(train_pos - val_pos)
            max_diff = max(max_diff, diff)
            print(f"{pathology:<30s} {train_pos:>9.2f}% {val_pos:>9.2f}% {diff:>11.2f}%")

        print("-" * 80)
        print(f"Maximum distribution difference: {max_diff:.2f}%")

        if max_diff < 2.0:
            print("✓ Excellent stratification (< 2% difference)")
        elif max_diff < 5.0:
            print("✓ Good stratification (< 5% difference)")
        else:
            print("⚠ Warning: Large distribution differences detected")

        # Check for class imbalance
        print("\n" + "=" * 80)
        print("Class Imbalance Analysis (Train Set)")
        print("=" * 80)

        imbalance_ratios = []
        for pathology in self._label_columns(train_df.columns):
            pos = (train_df[pathology] == 1).sum()
            neg = (train_df[pathology] == 0).sum()
            if pos > 0:
                ratio = neg / pos
                imbalance_ratios.append(ratio)
                severity = "Low" if ratio < 5 else "Medium" if ratio < 20 else "High"
                print(f"{pathology:<30s} Ratio: {ratio:>6.2f}:1 [{severity:>6s} imbalance]")

        # np.mean([]) would print "nan" and emit a RuntimeWarning.
        if imbalance_ratios:
            avg_imbalance = np.mean(imbalance_ratios)
            print(f"\nAverage imbalance ratio: {avg_imbalance:.2f}:1")

    def save_splits(self, train_df, val_df, output_dir='.'):
        """Save train and validation splits to CSV files and return their paths."""
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        train_path = output_dir / 'train_ready.csv'
        val_path = output_dir / 'val_ready.csv'

        # Remove temporary columns used for splitting
        train_df_clean = self._drop_helper_columns(train_df)
        val_df_clean = self._drop_helper_columns(val_df)

        train_df_clean.to_csv(train_path, index=False)
        val_df_clean.to_csv(val_path, index=False)

        print("\n" + "=" * 80)
        print("Files Saved Successfully")
        print("=" * 80)
        print(f"Train set: {train_path} ({len(train_df_clean)} images)")
        print(f"Val set: {val_path} ({len(val_df_clean)} images)")

        return train_path, val_path


# Main execution
if __name__ == "__main__":
    root = "/content/drive/MyDrive"

    # Configuration
    CHEXPERT_CSV = os.path.join(root, "CheXpert-v1.0-small", "train.csv")  # Adjust path as needed
    OUTPUT_DIR = os.path.join(root, "CheXpert-v1.0-small")
    VAL_SIZE = 0.15
    TEST_SIZE = VAL_SIZE
    RANDOM_STATE = 42
    USE_FRONTAL_ONLY = True
    FILL_UNCERTAIN = 'zeros'  # Options: 'zeros', 'ones', 'ignore'

    train_path = os.path.join(OUTPUT_DIR, "train_ready.csv")
    val_path = os.path.join(OUTPUT_DIR, "val_ready.csv")
    test_path = os.path.join(OUTPUT_DIR, "test_ready.csv")

    # Reuse existing splits only if all three files exist. Otherwise regenerate
    # them; the split is deterministic for a fixed RANDOM_STATE.
    if not all(os.path.exists(p) for p in (train_path, val_path, test_path)):
        splitter = CheXpertDataSplitter(
            csv_path=CHEXPERT_CSV,
            val_size=VAL_SIZE,
            test_size=TEST_SIZE,
            random_state=RANDOM_STATE,
            use_frontal_only=USE_FRONTAL_ONLY,
            fill_uncertain=FILL_UNCERTAIN,
            root=OUTPUT_DIR,
        )
        train_path, val_path, test_path = splitter.run(output_dir=OUTPUT_DIR)

    print("\nYou can now use these files with your CheXpertDataset class:")
    print(f" train_dataset = CheXpertDataset('{train_path}', root_dir='...', augment=True)")
    print(f" val_dataset = CheXpertDataset('{val_path}', root_dir='...', augment=False)")
    print(f" test_dataset = CheXpertDataset('{test_path}', root_dir='...', augment=False)")