File size: 1,595 Bytes
9d2fc01 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | """
utils.pseudo_envs — synthetic environment construction without site labels.
Used to compute the IRMv1 penalty when the dataset has no real
environment metadata. The default heuristic splits images by mean
brightness — a proxy for scanner / acquisition differences, which on
real medical data tend to correlate with site identity.
Replace with metadata-driven splitting once a labelled multi-site
dataset (e.g. WILDS-CheXpert) is plugged in.
"""
from __future__ import annotations
from typing import List, Dict
import torch
from torch.utils.data import DataLoader
def _squeeze_keep_batch(t: torch.Tensor) -> torch.Tensor:
    """Drop every size-1 dimension of ``t`` except the leading batch dimension.

    Equivalent to ``t.squeeze()`` for batches of two or more samples, but a
    batch of exactly one sample keeps its leading dimension instead of
    collapsing to a 0-d tensor (which ``torch.cat`` refuses to concatenate).
    """
    # Walk from the last dim towards dim 1 so squeezing never shifts the
    # index of a dimension that is still to be inspected.
    for dim in range(t.dim() - 1, 0, -1):
        if t.shape[dim] == 1:
            t = t.squeeze(dim)
    return t


def make_brightness_envs(dataset, n_envs: int, device: str) -> List[Dict[str, torch.Tensor]]:
    """
    Split ``dataset`` into ``n_envs`` pseudo-environments by image brightness.

    Each image's mean intensity is computed, the samples are sorted by that
    value, and the sorted order is cut into ``n_envs`` contiguous quantile
    bins whose sizes differ by at most one sample (the first
    ``N % n_envs`` bins get the extra sample).

    Args:
        dataset: A dataset yielding ``(image, label)`` pairs that the default
            ``DataLoader`` collate function can batch.
        n_envs: Number of environments; must satisfy ``1 <= n_envs <= N``.
        device: Device the returned tensors are moved to.

    Returns:
        A list of ``n_envs`` dicts ``{"x": images, "y": labels}`` ordered from
        the darkest bin to the brightest. ``x`` keeps the dataset's per-sample
        image shape and dtype (e.g. ``[N, C, H, W]``). ``y`` is int64 with
        non-batch singleton dims removed (``[N]`` for scalar labels).

    Raises:
        ValueError: if the dataset is empty or ``n_envs`` is outside
            ``[1, N]``, since that would produce an environment with no
            samples.
    """
    if n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {n_envs}")

    img_batches, label_batches = [], []
    for imgs, labels in DataLoader(dataset, batch_size=256, shuffle=False):
        img_batches.append(imgs)
        # A bare .squeeze() would turn a trailing batch of one sample into a
        # 0-d tensor and break torch.cat below.
        label_batches.append(_squeeze_keep_batch(labels).long())

    if not img_batches:
        raise ValueError("dataset is empty; cannot build brightness environments")

    all_imgs = torch.cat(img_batches)
    all_labels = torch.cat(label_batches)
    n_samples = all_imgs.shape[0]
    if n_envs > n_samples:
        raise ValueError(
            f"n_envs ({n_envs}) exceeds the number of samples ({n_samples}); "
            "some environments would be empty"
        )

    # Integer images (e.g. uint8) cannot be averaged directly; upcast only for
    # the statistic, so the returned images keep their original dtype.
    pixels = all_imgs if all_imgs.is_floating_point() else all_imgs.float()
    # Flatten each sample so any image rank ([N,H,W], [N,C,H,W], ...) works.
    brightness = pixels.reshape(n_samples, -1).mean(dim=1)
    sorted_idx = torch.argsort(brightness)

    # tensor_split balances the remainder across bins instead of piling it
    # onto the last one.
    return [
        {"x": all_imgs[idx].to(device), "y": all_labels[idx].to(device)}
        for idx in torch.tensor_split(sorted_idx, n_envs)
    ]
|