# BiliSakura's picture
# Add files using upload-large-folder tool
# cd99f4e verified
from pathlib import Path
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import h5py
import io
import torch.nn.functional as F
import torch
# Magnification name -> integer id. Ids follow the listing order below,
# which runs from "20x" downward, so "20x" is 0 and "0_15625x" is 7.
MAG_DICT = {
    mag_name: mag_id
    for mag_id, mag_name in enumerate(
        (
            "20x",
            "10x",
            "5x",
            "2_5x",
            "1_25x",
            "0_625x",
            "0_3125x",
            "0_15625x",
        )
    )
}
# Image count per magnification. The dataset uses these values both as its
# reported length and as the exclusive upper bound when drawing a random
# sample index for a magnification.
MAG_NUM_IMGS = {
    "20x": 12_509_760,
    "10x": 3_036_288,
    "5x": 752_000,
    "2_5x": 187_280,
    "1_25x": 57_090,
    "0_625x": 20_679,
    "0_3125x": 7_923,
    "0_15625x": 2_489,
}
class TCGADataset(Dataset):
    """Randomly sampled dataset of precomputed TCGA patch features.

    Samples are read from ``<root>/<mag>/<idx // 1_000_000>/`` where each
    sample ``idx`` is stored as ``<idx>_vae.npy``, ``<idx>_uni_grid.npy`` and,
    when images are requested, ``<idx>_img.npy``.

    ``__getitem__`` ignores the index it receives and always draws a fresh
    random sample, so iteration order and shuffling settings of a loader do
    not determine which samples are returned.

    NOTE(review): sampling uses NumPy's global RNG. With a multi-worker
    ``DataLoader`` the workers may start from the same NumPy seed unless a
    ``worker_init_fn`` reseeds it -- confirm against the training setup.

    Recognised ``config`` keys:
        root: Directory that contains one sub-directory per magnification
            (required).
        mag: Magnification name from ``MAG_DICT``; when unset or empty, a
            magnification is drawn uniformly at random for every sample.
        feat_target_size: Side length the SSL grid is average-pooled down to
            when it is larger; ``-1`` (default) disables pooling.
        return_image: Also load and decode the stored image (default False).
        normalize_ssl: Standardise the SSL feature along its first axis
            (default False).
        max_retries: Maximum number of random draws attempted per item when
            the VAE feature is missing or malformed (default 1000).
    """

    DEFAULT_MAX_RETRIES = 1000

    def __init__(self, config=None):
        """Read the dataset settings from ``config``.

        Raises:
            ValueError: If ``config`` or its ``root`` entry is missing, or if
                ``mag`` is not a known magnification.
        """
        super().__init__()
        if config is None:
            raise ValueError(
                "TCGADataset requires a config mapping with a 'root' entry"
            )
        root = config.get("root")
        if root is None:
            raise ValueError("TCGADataset config is missing the 'root' entry")
        self.root = Path(root)

        self.mag = config.get("mag", None)
        # Fail early instead of with a KeyError on the first len()/item access.
        if self.mag and (self.mag not in MAG_DICT or self.mag not in MAG_NUM_IMGS):
            raise ValueError(
                f"unknown magnification {self.mag!r}; expected one of {list(MAG_DICT)}"
            )

        self.keys = list(MAG_DICT.keys())
        self.feat_target_size = config.get("feat_target_size", -1)
        self.return_image = config.get("return_image", False)
        self.normalize_ssl = config.get("normalize_ssl", False)
        self.max_retries = int(config.get("max_retries", self.DEFAULT_MAX_RETRIES))

    def __len__(self):
        """Return the image count of the fixed magnification, else of "20x"."""
        if self.mag:
            return MAG_NUM_IMGS[self.mag]
        return MAG_NUM_IMGS["20x"]

    def __getitem__(self, idx):
        """Return one randomly drawn sample; the ``idx`` argument is ignored.

        Returns:
            dict with keys ``"image"`` (decoded uint8 image array, or a
            ``(1, 1, 1, 3)`` float16 placeholder of ones when images are not
            requested), ``"vae_feat"`` (float16 array of shape
            ``(3, 64, 64)``), ``"ssl_feat"`` (torch tensor), ``"idx"`` (the
            index that was actually sampled) and ``"mag"`` (integer id from
            ``MAG_DICT``).

        Raises:
            RuntimeError: If no valid VAE feature is found within
                ``max_retries`` random draws.
        """
        mag_choice, idx, folder_path, vae_feat = self._sample_vae_feat()
        ssl_feat = self._load_ssl_feat(folder_path, idx)
        image = self._load_image(folder_path, idx)
        return {
            "image": image,
            "vae_feat": vae_feat,
            "ssl_feat": ssl_feat,
            "idx": idx,
            "mag": MAG_DICT[mag_choice],
        }

    def _sample_vae_feat(self):
        """Draw random (magnification, index) pairs until a valid VAE feature loads."""
        last_error = None
        for _ in range(max(1, self.max_retries)):
            mag_choice = self.mag if self.mag else str(np.random.choice(self.keys))
            idx = np.random.randint(0, MAG_NUM_IMGS[mag_choice])
            # Samples are sharded into folders of one million indices each.
            folder_path = self.root / f"{mag_choice}/{idx // 1_000_000}"
            try:
                vae_feat = np.load(folder_path / f"{idx}_vae.npy").astype(np.float16)
            except Exception as err:
                # Best-effort: a missing or unreadable file just triggers a
                # redraw. KeyboardInterrupt/SystemExit are not swallowed.
                last_error = err
                continue
            if vae_feat.shape == (3, 64, 64):
                return mag_choice, idx, folder_path, vae_feat
            # Temporary workaround kept from the original code: skip samples
            # whose stored VAE feature has an unexpected shape.
            last_error = ValueError(f"vae shape {vae_feat.shape} for idx {idx}")
        raise RuntimeError(
            f"no valid VAE feature found under {self.root} "
            f"after {max(1, self.max_retries)} attempts"
        ) from last_error

    def _load_ssl_feat(self, folder_path, idx):
        """Load the SSL feature as a (C, h, h) tensor, optionally pooled and normalised."""
        ssl_feat = np.load(folder_path / f"{idx}_uni_grid.npy").astype(np.float16)
        if ssl_feat.ndim == 1:
            # A single feature vector is treated as a 1x1 grid.
            ssl_feat = ssl_feat[:, None]
        h = int(np.sqrt(ssl_feat.shape[1]))
        if ssl_feat.ndim == 2 and h * h != ssl_feat.shape[1]:
            # Otherwise reshape could silently produce a wrongly laid-out grid.
            raise ValueError(
                f"ssl feature for idx {idx} has shape {ssl_feat.shape}; "
                "second dimension is not a square grid"
            )
        ssl_feat = torch.tensor(ssl_feat.reshape((-1, h, h)))

        # Only ever shrink the grid; smaller grids are returned unchanged.
        if self.feat_target_size != -1 and h > self.feat_target_size:
            target = (self.feat_target_size, self.feat_target_size)
            ssl_feat = F.adaptive_avg_pool2d(ssl_feat, target)

        if self.normalize_ssl:
            # Standardise across the first axis, independently per grid cell.
            mean = ssl_feat.mean(dim=0, keepdim=True)
            std = ssl_feat.std(dim=0, keepdim=True)
            ssl_feat = (ssl_feat - mean) / (std + 1e-8)
        return ssl_feat

    def _load_image(self, folder_path, idx):
        """Decode the stored image, or return a placeholder when images are disabled."""
        if not self.return_image:
            return np.ones((1, 1, 1, 3), dtype=np.float16)
        # presumably the .npy holds the encoded image file bytes -- TODO confirm
        encoded = np.load(folder_path / f"{idx}_img.npy")
        decoded = Image.open(io.BytesIO(encoded))
        return np.array(decoded).astype(np.uint8)