"""Camelyon17 dataset utilities via WILDS."""
from collections import Counter

from wilds import get_dataset
def load_camelyon17(root_dir="data/wilds", download=False):
    """Return the WILDS Camelyon17 dataset object.

    Args:
        root_dir: Directory WILDS reads the dataset from.
        download: Forwarded unchanged to ``wilds.get_dataset``.

    Returns:
        Whatever ``wilds.get_dataset("camelyon17", ...)`` returns; call
        ``get_subset(<split name>)`` on it to obtain individual splits.
    """
    dataset_name = "camelyon17"
    return get_dataset(dataset_name, root_dir=root_dir, download=download)
def get_camelyon_subsets(root_dir="data/wilds", download=True):
    """Load Camelyon17 and return its train, ID-validation and OOD-test subsets.

    The subsets are WILDS' official splits, selected by split name. Hospital
    IDs are the 0-indexed values of the ``hospital`` metadata field. Per
    ``wilds/datasets/camelyon17_dataset.py``, ``VAL_CENTER = 1`` and
    ``TEST_CENTER = 2``.

    - ``"train"``: patches from hospitals 0, 3 and 4.
    - ``"id_val"``: held-out patches from the *same* hospitals as ``train``.
      This is in-distribution validation and introduces no new hospital.
    - ``"test"``: patches from hospital 2, an out-of-distribution hospital
      that is never seen in training.

    The OOD validation split (``"val"``, hospital 1) is not returned. Obtain
    it with ``dataset.get_subset("val")`` if needed.

    Args:
        root_dir: Directory passed to ``load_camelyon17``.
        download: Passed to ``load_camelyon17``. Note that it defaults to
            ``True`` here, unlike ``load_camelyon17`` itself.

    Returns:
        tuple: ``(train_subset, id_val_subset, ood_test_subset, dataset)``.
    """
    ds = load_camelyon17(root_dir=root_dir, download=download)
    train = ds.get_subset("train")      # hospitals 0, 3, 4
    id_val = ds.get_subset("id_val")    # held-out patches, same hospitals as train
    ood_test = ds.get_subset("test")    # hospital 2 (OOD)
    return train, id_val, ood_test, ds
def _hospital_counts(subset, column=0):
    """Return a ``Counter`` mapping hospital ID -> sample count for ``subset``.

    Args:
        subset: Object exposing a 2-D ``metadata_array``.
        column: Index of the hospital field within ``metadata_array``.
    """
    # ``tolist()`` works for both torch tensors and numpy arrays; ``int`` keeps
    # labels like "H3" rather than "H3.0" should the array be floating point.
    return Counter(int(h) for h in subset.metadata_array[:, column].tolist())


def camelyon_stats(root_dir="data/wilds"):
    """Print split sizes and the per-hospital breakdown of Camelyon17.

    Loads the subsets via ``get_camelyon_subsets(root_dir)``. That call uses
    its default ``download=True``, so this function may download the dataset.

    Hospital labels and counts are derived from each subset's metadata rather
    than hard-coded, so they always reflect the split actually loaded.

    Args:
        root_dir: Dataset directory forwarded to ``get_camelyon_subsets``.
    """
    train, id_val, ood_test, ds = get_camelyon_subsets(root_dir)

    # Find the hospital column by name. If the dataset does not expose that
    # field name, fall back to column 0 (the previous hard-coded assumption).
    fields = list(ds.metadata_fields)
    hosp_col = fields.index("hospital") if "hospital" in fields else 0

    train_counts = _hospital_counts(train, hosp_col)
    splits = (
        ("Train", train, train_counts),
        ("ID Val", id_val, _hospital_counts(id_val, hosp_col)),
        ("OOD Test", ood_test, _hospital_counts(ood_test, hosp_col)),
    )

    print("Camelyon17 Dataset Statistics")
    for name, subset, counts in splits:
        hosp_label = "+".join(f"H{h}" for h in sorted(counts)) or "empty"
        print(f" {name} ({hosp_label}): {len(subset):,} samples")

    # This is the sum of the three returned subsets only. The OOD validation
    # split ("val") is not loaded here, so it is not the full dataset size.
    total = sum(len(subset) for _, subset, _ in splits)
    print(f" Total (train + id_val + test): {total:,} samples")
    print(f" Metadata fields: {ds.metadata_fields}")

    # Iterate over the hospital IDs actually present in train instead of
    # assuming they are 0..2. Any count > 0 implies len(train) > 0, so the
    # percentage never divides by zero.
    print("\n Hospital distribution (train):")
    n_train = len(train)
    for hosp, count in sorted(train_counts.items()):
        print(f" Hospital {hosp}: {count:,} samples ({100 * count / n_train:.1f}%)")
|