| import pandas as pd |
| import hashlib |
| import os |
| from collections import defaultdict |
| from pathlib import Path |
|
|
def text_hash(text):
    """Return a SHA-256 hex digest of *text* after whitespace/case normalization.

    Normalization: coerce to str, lowercase, and collapse all runs of
    whitespace (including leading/trailing) to single spaces, so texts that
    differ only in case or spacing hash identically.
    """
    normalized = ' '.join(str(text).lower().split())
    return hashlib.sha256(normalized.encode()).hexdigest()
|
|
def remove_leaked_samples(train_path, val_path, test_path, output_dir='dataset/clean'):
    """Remove overlapping samples between dataset splits.

    Rows are compared by a normalized-text hash of the 'comment_text' column.
    Splits are processed in order (train, val, test) and each split is only
    filtered against the splits processed *before* it, so earlier splits take
    priority: 'train' keeps all rows, 'val' drops rows duplicated in 'train',
    and 'test' drops rows duplicated in either.

    Args:
        train_path: Path to the training-split CSV (must have 'comment_text').
        val_path: Path to the validation-split CSV.
        test_path: Path to the test-split CSV.
        output_dir: Directory where '<split>_clean.csv' files are written
            (created if missing).

    Returns:
        dict mapping split name ('train'/'val'/'test') to its cleaned DataFrame.
    """
    print("\n=== Removing Data Leakage ===\n")

    hash_registry = defaultdict(set)
    original_sizes = {}

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    print("Loading datasets...")
    splits = {
        'train': pd.read_csv(train_path),
        'val': pd.read_csv(val_path),
        'test': pd.read_csv(test_path)
    }

    for split_name, df in splits.items():
        original_sizes[split_name] = len(df)
        print(f"Original {split_name} size: {len(df):,} samples")

    print("\nChecking for overlaps...")
    removed_counts = defaultdict(int)

    for split_name, df in splits.items():
        print(f"\nProcessing {split_name} split...")

        # Hash every row exactly once per split; the previous version
        # recomputed text_hash over the whole column for each comparison.
        hash_series = df['comment_text'].apply(text_hash)
        current_hashes = set(hash_series)
        hash_registry[split_name] = current_hashes

        # Registry entries for splits not yet processed are still empty, so
        # each split is effectively compared only against earlier splits.
        for other_split in splits:
            if other_split == split_name:
                continue
            if not hash_registry[other_split]:
                continue
            overlaps = current_hashes & hash_registry[other_split]
            if overlaps:
                print(f" Found {len(overlaps):,} overlaps with {other_split}")
                # Keep df and its hash series aligned while filtering.
                keep_mask = ~hash_series.isin(overlaps)
                df = df[keep_mask]
                hash_series = hash_series[keep_mask]
                removed_counts[f"{split_name}_from_{other_split}"] += len(overlaps)

        splits[split_name] = df

    print("\nSaving cleaned datasets...")
    for split_name, df in splits.items():
        output_path = os.path.join(output_dir, f"{split_name}_clean.csv")
        df.to_csv(output_path, index=False)
        # Guard the percentage against an empty input split (the previous
        # version raised ZeroDivisionError in that case).
        denominator = original_sizes[split_name] or 1
        reduction = ((original_sizes[split_name] - len(df)) / denominator) * 100
        print(f"Cleaned {split_name}: {len(df):,} samples (-{reduction:.2f}%)")

    print("\nDetailed Overlap Statistics:")
    print("-" * 50)
    for overlap_type, count in removed_counts.items():
        split_name, other_split = overlap_type.split('_from_')
        print(f"{split_name} → {other_split}: {count:,} overlapping samples removed")

    return splits
|
|
def validate_cleaning(splits):
    """Check every pair of splits for remaining normalized-hash overlaps.

    Prints a per-pair verdict plus a final summary line; does not raise or
    return anything — this is a reporting-only sanity check.
    """
    print("\nValidating Cleaning...")
    print("-" * 50)

    # Hash each split once up front instead of once per pairwise comparison.
    hash_sets = {
        name: set(frame['comment_text'].apply(text_hash))
        for name, frame in splits.items()
    }

    all_clean = True
    for split1 in splits:
        for split2 in splits:
            # Compare each unordered pair once (lexicographic ordering).
            if not split1 < split2:
                continue
            shared = hash_sets[split1] & hash_sets[split2]
            if shared:
                print(f"⚠️ Warning: Found {len(shared)} overlaps between {split1} and {split2}")
                all_clean = False
            else:
                print(f"✅ No overlaps between {split1} and {split2}")

    if all_clean:
        print("\n✅ All splits are now clean with no overlaps!")
    else:
        print("\n⚠️ Some overlaps still remain. Consider additional cleaning.")
|
|
if __name__ == "__main__":
    # Input CSVs produced by the upstream splitting step.
    split_dir = "dataset/split"
    paths = {name: f"{split_dir}/{name}.csv" for name in ("train", "val", "test")}

    # De-duplicate across splits, then verify nothing overlapping remains.
    cleaned_splits = remove_leaked_samples(paths["train"], paths["val"], paths["test"])
    validate_cleaning(cleaned_splits)