CausalGrok / code /utils /camelyon_data.py
nileshsarkar-ai's picture
Upload code/utils
9d2fc01 verified
"""Camelyon17 dataset utilities via WILDS."""
from wilds import get_dataset
def load_camelyon17(root_dir="data/wilds", download=False):
    """Fetch the Camelyon17 dataset object through the WILDS loader.

    Args:
        root_dir: Directory handed to WILDS as the dataset root.
        download: Forwarded to WILDS as its ``download`` flag
            (defaults to ``False`` here).

    Returns:
        Whatever ``wilds.get_dataset("camelyon17", ...)`` returns; callers in
        this module use its ``get_subset`` method and ``metadata_fields``.
    """
    loader_options = {"root_dir": root_dir, "download": download}
    return get_dataset("camelyon17", **loader_options)
def get_camelyon_subsets(root_dir="data/wilds", download=True):
    """
    Load Camelyon17 and return the train / ID-val / OOD-test subsets.

    Hospital split (official WILDS scheme, 0-indexed hospital IDs as stored
    in the dataset metadata):
        - "train":  patches from the three training hospitals (0, 3, 4).
        - "id_val": in-distribution validation -- held-out patches drawn
          from the SAME hospitals as "train", not from an unseen hospital.
        - "test":   out-of-distribution test -- a hospital that never
          appears in training (hospital 2).

    WILDS additionally defines a "val" split (OOD validation, hospital 1);
    it is not returned by this helper.

    NOTE(review): the hospital IDs above follow the WILDS Camelyon17 loader;
    confirm against the hospital column of each subset's ``metadata_array``
    if the installed WILDS version differs.

    Args:
        root_dir: Directory handed to WILDS as the dataset root.
        download: Forwarded to WILDS. Defaults to ``True`` here, unlike
            ``load_camelyon17`` whose default is ``False``.

    Returns:
        tuple: (train_subset, id_val_subset, ood_test_subset, dataset)
    """
    ds = load_camelyon17(root_dir=root_dir, download=download)

    train = ds.get_subset("train")      # training hospitals
    id_val = ds.get_subset("id_val")    # same hospitals as train, held out
    ood_test = ds.get_subset("test")    # unseen hospital

    return train, id_val, ood_test, ds
def camelyon_stats(root_dir="data/wilds"):
    """Print split sizes and the per-hospital makeup of the train split.

    Hospital IDs are read from each subset's metadata instead of being
    hard-coded: the WILDS training hospitals are not the contiguous range
    0..2, so iterating ``range(3)`` would report empty hospitals and omit
    real ones.

    Args:
        root_dir: Directory passed through to ``get_camelyon_subsets``
            (which uses its own default for ``download``).

    Returns:
        None. All output goes to stdout. "Total" covers only the three
        subsets loaded here.
    """
    from collections import Counter

    train, id_val, ood_test, ds = get_camelyon_subsets(root_dir)

    # Find the hospital column by name; fall back to column 0, which is what
    # this function assumed before the lookup was added.
    fields = list(ds.metadata_fields)
    hospital_col = fields.index("hospital") if "hospital" in fields else 0

    def hospital_counts(subset):
        """Return {hospital_id: sample_count} for one subset, sorted by ID."""
        ids = subset.metadata_array[:, hospital_col].tolist()
        return dict(sorted(Counter(ids).items()))

    def hospital_label(counts):
        """Render the hospitals present in a subset, e.g. 'H0+H3+H4'."""
        return "+".join(f"H{h}" for h in counts) or "none"

    train_counts = hospital_counts(train)
    id_val_counts = hospital_counts(id_val)
    test_counts = hospital_counts(ood_test)

    n_train, n_id_val, n_test = len(train), len(id_val), len(ood_test)

    print("Camelyon17 Dataset Statistics")
    print(f" Train ({hospital_label(train_counts)}): {n_train:,} samples")
    print(f" ID Val ({hospital_label(id_val_counts)}): {n_id_val:,} samples")
    print(f" OOD Test ({hospital_label(test_counts)}): {n_test:,} samples")
    print(f" Total: {n_train + n_id_val + n_test:,} samples")
    print(f" Metadata fields: {ds.metadata_fields}")

    # Hospital distribution
    print("\n Hospital distribution (train):")
    for h, count in train_counts.items():
        # Guard the percentage so an empty split cannot divide by zero.
        share = 100 * count / n_train if n_train else 0.0
        print(f" Hospital {h}: {count:,} samples ({share:.1f}%)")