File size: 1,817 Bytes
9d2fc01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""Camelyon17 dataset utilities via WILDS."""

from wilds import get_dataset


def load_camelyon17(root_dir="data/wilds", download=False):
    """Fetch the full Camelyon17 dataset object through WILDS.

    Args:
        root_dir: Directory handed to WILDS as ``root_dir``.
        download: Forwarded to WILDS as ``download``; defaults to ``False``.

    Returns:
        Whatever ``wilds.get_dataset`` returns for ``"camelyon17"``; the
        individual splits are obtained from it via ``get_subset``.
    """
    wilds_kwargs = {"download": download, "root_dir": root_dir}
    dataset = get_dataset("camelyon17", **wilds_kwargs)
    return dataset


def get_camelyon_subsets(root_dir="data/wilds", download=True):
    """
    Load Camelyon17 and return the train / ID-val / OOD-test subsets.

    The subsets are exactly what WILDS defines for the split names
    ``"train"``, ``"id_val"`` and ``"test"``; this function does not pick
    hospitals itself. Per the WILDS Camelyon17 split definition (confirm
    against the installed ``wilds`` version):
      - Train: patches from the three training hospitals (centers 0, 3, 4)
      - ID Val: held-out patches from the *same* training hospitals, so it
        measures in-distribution performance, not a new hospital
      - OOD Test: a hospital that is not used for training (center 2)
    WILDS also provides an OOD validation split (``"val"``, center 1), which
    is not returned here.

    Args:
        root_dir: Directory forwarded to ``load_camelyon17``.
        download: Forwarded to ``load_camelyon17``. Note that the default
            here is ``True``, whereas ``load_camelyon17`` defaults to
            ``False``.

    Returns:
        tuple: (train_subset, id_val_subset, ood_test_subset, dataset)
    """
    ds = load_camelyon17(root_dir=root_dir, download=download)

    # Split names are WILDS identifiers; order matches the returned tuple.
    train, id_val, ood_test = (
        ds.get_subset(split) for split in ("train", "id_val", "test")
    )

    return train, id_val, ood_test, ds


def camelyon_stats(root_dir="data/wilds"):
    """Print Camelyon17 split sizes and the per-hospital breakdown of train.

    Hospital IDs are read from each subset's metadata instead of being
    hard-coded, so the labels and the distribution always reflect the
    hospitals that are actually present in a split.

    Note: this calls ``get_camelyon_subsets(root_dir)`` with its default
    ``download`` argument, so the call may trigger a download.

    Args:
        root_dir: Directory forwarded to ``get_camelyon_subsets``.

    Returns:
        None. Everything is written to stdout.
    """
    # Function-scope import so the module's import block needs no change.
    from collections import Counter

    train, id_val, ood_test, ds = get_camelyon_subsets(root_dir)

    # Find the hospital column by name; fall back to column 0, which is what
    # the original implementation assumed.
    fields = list(ds.metadata_fields)
    hospital_col = fields.index("hospital") if "hospital" in fields else 0

    def hospital_counts(subset):
        """Return {hospital_id: sample_count} for one subset, sorted by ID."""
        ids = subset.metadata_array[:, hospital_col].tolist()
        return dict(sorted(Counter(int(h) for h in ids).items()))

    def row(label, value):
        """Print one aligned 'label: value' line of the summary table."""
        print(f"  {label + ':':<21} {value}")

    train_counts = hospital_counts(train)
    splits = (
        ("Train", train, train_counts),
        ("ID Val", id_val, hospital_counts(id_val)),
        ("OOD Test", ood_test, hospital_counts(ood_test)),
    )

    print("Camelyon17 Dataset Statistics")
    for name, subset, counts in splits:
        # e.g. "H0+H3+H4": the hospitals actually present in this split.
        hospitals = "+".join(f"H{h}" for h in counts)
        row(f"{name} ({hospitals})", f"{len(subset):,} samples")
    total = len(train) + len(id_val) + len(ood_test)
    row("Total", f"{total:,} samples")
    row("Metadata fields", ds.metadata_fields)

    # Hospital distribution. Iterating the observed IDs means an empty train
    # split prints nothing here rather than dividing by zero.
    print("\n  Hospital distribution (train):")
    for h, count in train_counts.items():
        print(f"    Hospital {h}: {count:,} samples ({100*count/len(train):.1f}%)")