| """Camelyon17 dataset utilities via WILDS.""" |
|
|
| from wilds import get_dataset |
|
|
|
|
def load_camelyon17(root_dir="data/wilds", download=False):
    """Fetch the Camelyon17 dataset object through the WILDS loader.

    Args:
        root_dir: Directory handed to WILDS as the dataset root.
        download: Forwarded to WILDS unchanged. Defaults to False here.

    Returns:
        Whatever ``wilds.get_dataset`` returns for ``"camelyon17"``. Other
        functions in this module call ``get_subset`` on it to obtain splits.
    """
    loader_kwargs = {"download": download, "root_dir": root_dir}
    return get_dataset("camelyon17", **loader_kwargs)
|
|
|
|
def get_camelyon_subsets(root_dir="data/wilds", download=True):
    """
    Load Camelyon17 and return the train / ID-val / OOD-test subsets.

    The splits are requested from WILDS by name; this function does not
    select or filter hospitals itself:
      - "train":  training patches.
      - "id_val": in-distribution validation. Per the WILDS Camelyon17
        documentation these are held-out patches from the SAME hospitals
        as the training split, not from a separate hospital.
      - "test":   out-of-distribution test, from a hospital that does not
        appear in the training split.

    NOTE(review): WILDS also defines an OOD validation split named "val",
    which is not returned here. The hospital IDs behind each split are
    fixed by WILDS; inspect the hospital column of ``metadata_array``
    instead of assuming particular IDs.

    Args:
        root_dir: Directory WILDS uses as the dataset root.
        download: Forwarded to WILDS. Defaults to True here, unlike
            ``load_camelyon17`` whose default is False.

    Returns:
        tuple: (train_subset, id_val_subset, ood_test_subset, dataset)
    """
    ds = load_camelyon17(root_dir=root_dir, download=download)

    train = ds.get_subset("train")
    id_val = ds.get_subset("id_val")
    ood_test = ds.get_subset("test")

    return train, id_val, ood_test, ds
|
|
|
|
def _hospital_tag(subset, column=0):
    """Return a label such as ``"H0+H3"`` naming the hospitals found in *subset*."""
    # Tensor.unique() returns sorted values by default, so the label is stable.
    hospital_ids = subset.metadata_array[:, column].unique().tolist()
    return "+".join(f"H{int(h)}" for h in hospital_ids)


def camelyon_stats(root_dir="data/wilds", download=True):
    """Print split sizes and the per-hospital breakdown of the training split.

    Hospital IDs are read from each subset's metadata instead of being
    hard-coded, so the printed labels and the distribution table always
    reflect the hospitals that are actually present.

    Args:
        root_dir: Directory WILDS uses as the dataset root.
        download: Forwarded to ``get_camelyon_subsets``. Defaults to True,
            which matches the behaviour before this parameter existed.

    Returns:
        None. All output goes to stdout.
    """
    train, id_val, ood_test, ds = get_camelyon_subsets(root_dir, download=download)

    # Locate the hospital column by name; fall back to column 0, which is
    # what this function used before the lookup was added.
    fields = list(ds.metadata_fields)
    hospital_col = fields.index("hospital") if "hospital" in fields else 0

    splits = (("Train", train), ("ID Val", id_val), ("OOD Test", ood_test))

    print("Camelyon17 Dataset Statistics")
    for label, subset in splits:
        tag = _hospital_tag(subset, hospital_col)
        print(f" {label} ({tag}): {len(subset):,} samples")
    total = sum(len(subset) for _, subset in splits)
    print(f" Total: {total:,} samples")
    print(f" Metadata fields: {ds.metadata_fields}")

    hospitals = train.metadata_array[:, hospital_col]
    print("\n Hospital distribution (train):")
    # Count every hospital ID present, not just 0..2.
    ids, counts = hospitals.unique(return_counts=True)
    n_train = len(train)
    for h, count in zip(ids.tolist(), counts.tolist()):
        # Guard against an empty training subset to avoid ZeroDivisionError.
        pct = 100 * count / n_train if n_train else 0.0
        print(f" Hospital {h}: {count:,} samples ({pct:.1f}%)")
|
|