File size: 3,006 Bytes
cd99f4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from pathlib import Path
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import h5py
import io
import torch.nn.functional as F
import torch


# Integer label for every magnification key. Labels follow the listing order:
# 0 for "20x" up to 7 for "0_15625x". TCGADataset returns this label under
# the "mag" entry of each sample.
MAG_DICT = {
    mag_name: label
    for label, mag_name in enumerate(
        (
            "20x",
            "10x",
            "5x",
            "2_5x",
            "1_25x",
            "0_625x",
            "0_3125x",
            "0_15625x",
        )
    )
}

# Sample count per magnification key. TCGADataset uses these values as the
# dataset length and as the exclusive upper bound when drawing a random index.
MAG_NUM_IMGS = dict(
    [
        ("20x", 12_509_760),
        ("10x", 3_036_288),
        ("5x", 752_000),
        ("2_5x", 187_280),
        ("1_25x", 57_090),
        ("0_625x", 20_679),
        ("0_3125x", 7_923),
        ("0_15625x", 2_489),
    ]
)


class TCGADataset(Dataset):
    """Dataset of pre-computed per-sample feature files stored as ``.npy``.

    Files are looked up at ``<root>/<mag>/<idx // 1_000_000>/`` and named
    ``<idx>_vae.npy``, ``<idx>_uni_grid.npy`` and (optionally) ``<idx>_img.npy``.

    Config keys (read with ``config.get``):
        root: Base directory of the dataset (required).
        mag: A key of ``MAG_DICT``. When falsy or absent, every
            ``__getitem__`` call draws a random magnification and a random
            index, ignoring the index passed in.
        feat_target_size: When not ``-1``, SSL grids larger than this side
            length are average-pooled down to it. Default ``-1``.
        return_image: When true, also load and decode ``<idx>_img.npy``.
            Default ``False``.
        normalize_ssl: When true, standardise the SSL feature along its
            first axis. Default ``False``.

    Raises:
        ValueError: If ``config`` is ``None`` or ``mag`` is not a known key.
    """

    # Upper bound on re-sampling attempts when a VAE file is missing or
    # malformed. Chosen to be close to the default interpreter recursion
    # limit, which bounded the previous recursive retry.
    _MAX_LOAD_ATTEMPTS = 1000

    def __init__(self, config=None):
        if config is None:
            raise ValueError(
                "TCGADataset requires a config mapping with at least a 'root' entry"
            )

        self.root = Path(config.get("root"))
        self.mag = config.get("mag", None)
        # Fail early on an unknown magnification: it could never be loaded
        # and would otherwise surface as a KeyError in __len__.
        if self.mag and self.mag not in MAG_NUM_IMGS:
            raise ValueError(
                f"unknown mag {self.mag!r}; expected one of {list(MAG_NUM_IMGS)}"
            )

        self.keys = list(MAG_DICT.keys())
        self.feat_target_size = config.get("feat_target_size", -1)
        self.return_image = config.get("return_image", False)
        self.normalize_ssl = config.get("normalize_ssl", False)

    def __len__(self):
        """Return the sample count of the fixed magnification, else of "20x"."""
        if self.mag:
            return MAG_NUM_IMGS[self.mag]
        return MAG_NUM_IMGS["20x"]

    @staticmethod
    def _load_vae_feat(folder_path, idx):
        """Load ``<idx>_vae.npy`` as float16; return None if unusable.

        "Unusable" means the file could not be loaded or its shape is not
        ``(3, 64, 64)``.
        """
        try:
            vae_feat = np.load(folder_path / f"{idx}_vae.npy").astype(np.float16)
        except Exception:
            # Best-effort: any load failure means "try another sample".
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            return None
        if vae_feat.shape != (3, 64, 64):
            return None
        return vae_feat

    def __getitem__(self, idx):
        """Return one sample as a dict.

        If the VAE feature for the requested sample is missing or malformed,
        a different random sample is returned instead, so the ``"idx"`` entry
        of the result may differ from the ``idx`` argument.

        Returns:
            dict with keys ``"image"``, ``"vae_feat"``, ``"ssl_feat"``,
            ``"idx"`` and ``"mag"`` (the ``MAG_DICT`` label).

        Raises:
            RuntimeError: If no valid VAE feature is found within
                ``_MAX_LOAD_ATTEMPTS`` attempts.
        """
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            if self.mag:
                mag_choice = self.mag
            else:
                mag_choice = np.random.choice(self.keys)
                # pick a random index
                idx = np.random.randint(0, MAG_NUM_IMGS[mag_choice])

            ##### load VAE feat
            folder = str(idx // 1_000_000)
            folder_path = self.root / f"{mag_choice}/{folder}"

            vae_feat = self._load_vae_feat(folder_path, idx)
            if vae_feat is not None:
                break
            # Re-sample and retry, consuming random numbers in the same order
            # as the former recursive implementation.
            idx = np.random.randint(len(self))
        else:
            raise RuntimeError(
                f"could not load a valid VAE feature under {self.root} "
                f"after {self._MAX_LOAD_ATTEMPTS} attempts"
            )

        ###### load SSL feature
        ssl_feat = np.load(folder_path / f"{idx}_uni_grid.npy").astype(np.float16)

        if len(ssl_feat.shape) == 1:
            ssl_feat = ssl_feat[:, None]
        # Side length of the square grid encoded along the second axis.
        h = np.sqrt(ssl_feat.shape[1]).astype(int)

        ssl_feat = torch.tensor(ssl_feat.reshape((-1, h, h)))

        # resize ssl_feat
        if self.feat_target_size != -1 and h > self.feat_target_size:
            shape = (self.feat_target_size, self.feat_target_size)
            ssl_feat = F.adaptive_avg_pool2d(ssl_feat, shape)

        # normalize ssl_feat
        if self.normalize_ssl:
            # Compute in float32: 1e-8 is below float16 resolution, so in
            # half precision the epsilon vanishes and a zero std yields
            # NaN/inf. The result is cast back to the original dtype.
            feat32 = ssl_feat.float()
            mean = feat32.mean(dim=0, keepdim=True)
            std = feat32.std(dim=0, keepdim=True)
            ssl_feat = ((feat32 - mean) / (std + 1e-8)).to(ssl_feat.dtype)

        #### load image
        if self.return_image:
            # The .npy file holds the bytes of an encoded image; decode via PIL.
            image = np.load(folder_path / f"{idx}_img.npy")
            image = Image.open(io.BytesIO(image))
            image = np.array(image).astype(np.uint8)

        else:
            # Placeholder so the "image" key is always present.
            image = np.ones((1, 1, 1, 3), dtype=np.float16)

        return {
            "image": image,
            "vae_feat": vae_feat,
            "ssl_feat": ssl_feat,
            "idx": idx,
            "mag": MAG_DICT[mag_choice],
        }