| | from pathlib import Path |
| | import numpy as np |
| | from torch.utils.data import Dataset |
| | from PIL import Image |
| | import h5py |
| | import io |
| | import torch.nn.functional as F |
| | import torch |
| |
|
| |
|
# Magnification label -> integer id. Ids are assigned in listing order, so
# "20x" is 0 and "0_15625x" is 7. The labels presumably use "_" in place of
# the decimal point (e.g. "2_5x" for 2.5x).
MAG_DICT = {
    label: mag_id
    for mag_id, label in enumerate(
        (
            "20x",
            "10x",
            "5x",
            "2_5x",
            "1_25x",
            "0_625x",
            "0_3125x",
            "0_15625x",
        )
    )
}
| |
|
# Per-magnification sample count. The dataset uses these values as its
# reported length and as the exclusive upper bound when drawing random
# sample indices.
MAG_NUM_IMGS = dict(
    [
        ("20x", 12_509_760),
        ("10x", 3_036_288),
        ("5x", 752_000),
        ("2_5x", 187_280),
        ("1_25x", 57_090),
        ("0_625x", 20_679),
        ("0_3125x", 7_923),
        ("0_15625x", 2_489),
    ]
)
| |
|
| |
|
class TCGADataset(Dataset):
    """Dataset over pre-extracted per-sample ``.npy`` feature files.

    Files for sample ``idx`` at magnification ``mag`` are read from
    ``<root>/<mag>/<idx // 1_000_000>/``:

    * ``<idx>_vae.npy``      -- must have shape ``(3, 64, 64)``
    * ``<idx>_uni_grid.npy`` -- reshaped to ``(-1, h, h)`` where ``h`` is the
      integer square root of its second dimension
    * ``<idx>_img.npy``      -- encoded image bytes, only read when
      ``return_image`` is enabled

    Config keys (all read with ``config.get``):

    * ``root`` (required): base directory of the feature files.
    * ``mag``: a key of ``MAG_NUM_IMGS`` to pin the magnification. When
      unset, every ``__getitem__`` call draws a random magnification and a
      random index, ignoring the index it was given.
    * ``feat_target_size`` (default ``-1``): if not ``-1`` and smaller than
      the ssl grid side, the grid is average-pooled down to this size.
    * ``return_image`` (default ``False``): also load and decode the image.
    * ``normalize_ssl`` (default ``False``): standardize ssl features along
      the first axis.
    * ``max_load_attempts`` (default ``100``): how many samples to try
      before giving up when the vae file is missing or malformed.
    """

    def __init__(self, config=None):
        """Read settings from ``config``.

        Raises:
            ValueError: if ``config`` has no ``root`` entry, or ``mag`` is
                set to a label that is not in ``MAG_NUM_IMGS``.
        """
        super().__init__()
        # The default ``config=None`` used to crash on ``None.get``.
        if config is None:
            config = {}

        root = config.get("root")
        if root is None:
            raise ValueError("TCGADataset requires config['root']")
        self.root = Path(root)

        self.mag = config.get("mag", None)
        # Fail at construction rather than with a KeyError on first use.
        if self.mag and self.mag not in MAG_NUM_IMGS:
            raise ValueError(
                f"unknown mag {self.mag!r}; expected one of {list(MAG_NUM_IMGS)}"
            )

        self.keys = list(MAG_DICT.keys())
        self.feat_target_size = config.get("feat_target_size", -1)
        self.return_image = config.get("return_image", False)
        self.normalize_ssl = config.get("normalize_ssl", False)
        self.max_load_attempts = max(1, int(config.get("max_load_attempts", 100)))

    def __len__(self):
        """Return the sample count of the pinned magnification, else of "20x"."""
        if self.mag:
            return MAG_NUM_IMGS[self.mag]
        return MAG_NUM_IMGS["20x"]

    def __getitem__(self, idx):
        """Load one sample, resampling when its vae file is unusable.

        Returns:
            dict with keys ``image`` (uint8 array, or a ones placeholder of
            shape ``(1, 1, 1, 3)`` when images are disabled), ``vae_feat``
            (float16 array of shape ``(3, 64, 64)``), ``ssl_feat`` (torch
            tensor), ``idx`` (index actually loaded) and ``mag`` (integer id
            from ``MAG_DICT``).

        Raises:
            RuntimeError: if no usable vae file is found within
                ``max_load_attempts`` attempts.
        """
        last_error = None
        # Bounded loop instead of unbounded recursion, so a broken or empty
        # dataset surfaces as a clear error rather than a RecursionError.
        for _ in range(self.max_load_attempts):
            if self.mag:
                mag_choice = self.mag
            else:
                # No pinned magnification: the incoming index is discarded
                # and both magnification and index are drawn at random.
                mag_choice = np.random.choice(self.keys)
                idx = np.random.randint(0, MAG_NUM_IMGS[mag_choice])

            folder_path = self.root / f"{mag_choice}/{idx // 1_000_000}"

            try:
                vae_feat = np.load(folder_path / f"{idx}_vae.npy").astype(np.float16)
            except Exception as err:
                # Best-effort skip of missing or corrupt files. Deliberately
                # not a bare ``except`` so KeyboardInterrupt and SystemExit
                # still propagate.
                last_error = err
            else:
                if vae_feat.shape == (3, 64, 64):
                    break
                last_error = ValueError(f"vae shape {vae_feat.shape} for idx {idx}")

            # Retry with a fresh random index.
            idx = np.random.randint(len(self))
        else:
            raise RuntimeError(
                f"no valid vae feature found under {self.root} after "
                f"{self.max_load_attempts} attempts"
            ) from last_error

        # NOTE(review): unlike the vae file, a missing or corrupt ssl file is
        # not retried and propagates to the caller -- confirm this is intended.
        ssl_feat = np.load(folder_path / f"{idx}_uni_grid.npy").astype(np.float16)

        if ssl_feat.ndim == 1:
            ssl_feat = ssl_feat[:, None]
        # Side of the square grid. The second dimension is assumed to be a
        # perfect square; the sqrt is truncated, not validated.
        h = int(np.sqrt(ssl_feat.shape[1]))

        ssl_feat = torch.tensor(ssl_feat.reshape((-1, h, h)))

        # Only ever downsample the grid, never upsample.
        if self.feat_target_size != -1 and h > self.feat_target_size:
            shape = (self.feat_target_size, self.feat_target_size)
            ssl_feat = F.adaptive_avg_pool2d(ssl_feat, shape)

        if self.normalize_ssl:
            mean = ssl_feat.mean(dim=0, keepdim=True)
            std = ssl_feat.std(dim=0, keepdim=True)
            ssl_feat = (ssl_feat - mean) / (std + 1e-8)

        if self.return_image:
            # The .npy file holds an encoded image as raw bytes; decode via PIL.
            image = np.load(folder_path / f"{idx}_img.npy")
            image = Image.open(io.BytesIO(image))
            image = np.array(image).astype(np.uint8)
        else:
            # Placeholder so the returned dict always carries an "image" entry.
            image = np.ones((1, 1, 1, 3), dtype=np.float16)

        return {
            "image": image,
            "vae_feat": vae_feat,
            "ssl_feat": ssl_feat,
            "idx": idx,
            "mag": MAG_DICT[mag_choice],
        }
| |
|