File size: 1,595 Bytes
9d2fc01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
utils.pseudo_envs — synthetic environment construction without site labels.

Used to compute the IRMv1 penalty when the dataset has no real
environment metadata. The heuristic implemented here splits images by
mean brightness — a proxy for scanner / acquisition differences, which
on real medical data tend to correlate with site identity.

Replace with metadata-driven splitting once a labelled multi-site
dataset (e.g. WILDS-CheXpert) is plugged in.
"""

from __future__ import annotations

from typing import List, Dict

import torch
from torch.utils.data import DataLoader


def make_brightness_envs(dataset, n_envs: int, device: str) -> List[Dict[str, torch.Tensor]]:
    """
    Split `dataset` into `n_envs` pseudo-environments by mean image brightness.

    The whole dataset is loaded into memory in its original order (256 samples
    per DataLoader batch). Each image is reduced to its mean pixel value, and
    the samples are sorted by that value. The sorted order is then cut into
    `n_envs` contiguous bins. Every bin holds ``len(dataset) // n_envs`` samples
    except the last, which also absorbs the remainder.

    Args:
        dataset: dataset yielding ``(image, label)`` pairs that the default
            DataLoader collate can batch. Images may have any rank
            ([C,H,W], [H,W], [C,D,H,W], ...) and any numeric dtype. Labels
            are expected to be one scalar per sample (shape [] or [1]).
        n_envs: number of environments; must satisfy
            ``1 <= n_envs <= len(dataset)``.
        device: device the returned tensors are moved to.

    Returns:
        A list of `n_envs` dicts ``{"x": images [N_i, ...], "y": int64 labels [N_i]}``,
        ordered from the darkest bin to the brightest.

    Raises:
        ValueError: if the dataset is empty or `n_envs` is out of range.
    """
    if n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {n_envs}")

    all_imgs, all_labels = [], []
    for imgs, labels in DataLoader(dataset, batch_size=256, shuffle=False):
        all_imgs.append(imgs)
        # Flatten per-sample labels to shape [B]. A bare `.squeeze()` would turn
        # a final batch of size 1 into a 0-d tensor, which torch.cat rejects.
        all_labels.append(labels.reshape(labels.shape[0], -1).squeeze(1).long())

    if not all_imgs:
        raise ValueError("cannot build environments from an empty dataset")

    all_imgs = torch.cat(all_imgs)
    all_labels = torch.cat(all_labels)
    n_samples = len(all_imgs)
    if n_envs > n_samples:
        raise ValueError(
            f"n_envs ({n_envs}) exceeds dataset size ({n_samples}); "
            "some environments would be empty"
        )

    # Average over every non-batch dimension so any image rank works. Integer
    # images (e.g. uint8) are promoted because torch.mean needs a float dtype.
    flat = all_imgs.reshape(n_samples, -1)
    if not flat.is_floating_point():
        flat = flat.float()
    brightness = flat.mean(dim=1)
    sorted_idx = torch.argsort(brightness)

    env_size = n_samples // n_envs
    envs = []
    for i in range(n_envs):
        start = i * env_size
        # The last bin takes the remainder so no sample is dropped.
        end = start + env_size if i < n_envs - 1 else n_samples
        idx = sorted_idx[start:end]
        envs.append({
            "x": all_imgs[idx].to(device),
            "y": all_labels[idx].to(device),
        })
    return envs