| """ |
| utils.pseudo_envs — synthetic environment construction without site labels. |
| |
| Used to compute the IRMv1 penalty when the dataset has no real |
| environment metadata. The default heuristic splits images by mean |
| brightness — a proxy for scanner / acquisition differences which on |
| real medical data correlates with site identity. |
| |
| Replace with metadata-driven splitting once a labelled multi-site |
| dataset (e.g. WILDS-CheXpert) is plugged in. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import List, Dict |
|
|
| import torch |
| from torch.utils.data import DataLoader |
|
|
|
|
def make_brightness_envs(
    dataset,
    n_envs: int,
    device: str,
    *,
    batch_size: int = 256,
) -> List[Dict[str, torch.Tensor]]:
    """
    Split ``dataset`` into ``n_envs`` pseudo-environments by image brightness.

    Every sample is read in dataset order through a non-shuffling DataLoader.
    Each sample is scored by its mean pixel value, and the samples are sorted
    from darkest to brightest. The sorted order is then cut into ``n_envs``
    contiguous slices of ``N // n_envs`` samples each. The last slice also
    absorbs the ``N % n_envs`` leftover samples.

    The whole dataset is materialised in memory, and every environment is
    moved to ``device``.

    Args:
        dataset: Map- or iterable-style dataset yielding ``(image, label)``
            pairs. Images are assumed to be tensors such as ``[C, H, W]``.
            Labels are scalar class ids, possibly wrapped in singleton
            dimensions (e.g. ``[1]``).
        n_envs: Number of environments to build. Must satisfy
            ``1 <= n_envs <= len(dataset)`` so that no environment is empty.
        device: Target device for the returned tensors.
        batch_size: Loader batch size used while reading the dataset. It only
            affects speed and peak memory, not the result.

    Returns:
        A list of ``n_envs`` dicts ``{"x": tensor [n_i, ...], "y": tensor [n_i]}``,
        ordered from the darkest to the brightest environment. ``y`` is int64
        and ``x`` keeps the dataset's original dtype.

    Raises:
        ValueError: If ``n_envs < 1``, the dataset is empty, or ``n_envs``
            exceeds the number of samples.
    """
    if n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {n_envs}")

    img_batches, label_batches = [], []
    for imgs, labels in DataLoader(dataset, batch_size=batch_size, shuffle=False):
        img_batches.append(imgs)
        # Drop singleton non-batch dims (e.g. [B, 1] -> [B]) but never the
        # batch dim itself. A plain .squeeze() turns a final batch of size 1
        # into a 0-d tensor, which torch.cat refuses to concatenate.
        kept_shape = (labels.shape[0], *(s for s in labels.shape[1:] if s != 1))
        label_batches.append(labels.reshape(kept_shape).long())

    if not img_batches:
        raise ValueError("dataset is empty; cannot build brightness environments")
    all_imgs = torch.cat(img_batches)
    all_labels = torch.cat(label_batches)

    n_samples = len(all_imgs)
    if n_envs > n_samples:
        # Otherwise some environments would be empty, and an empty batch
        # yields NaN losses/penalties downstream.
        raise ValueError(
            f"n_envs ({n_envs}) exceeds the number of samples ({n_samples})"
        )

    # Integer images (e.g. uint8) cannot be averaged directly. Score them in
    # float32 while leaving the returned pixels in their original dtype.
    pixels = all_imgs if all_imgs.is_floating_point() else all_imgs.float()
    brightness = pixels.flatten(start_dim=1).mean(dim=1)
    # Stable sort: equally bright images keep dataset order, so the split is
    # reproducible even when many images tie (e.g. blank frames).
    sorted_idx = torch.sort(brightness, stable=True).indices

    env_size = n_samples // n_envs
    envs = []
    for i in range(n_envs):
        start = i * env_size
        # The last environment absorbs the remainder of the integer division.
        end = n_samples if i == n_envs - 1 else start + env_size
        idx = sorted_idx[start:end]
        envs.append({
            "x": all_imgs[idx].to(device),
            "y": all_labels[idx].to(device),
        })
    return envs
|
|