import os
import torch
from openslide import OpenSlide
from utils.preprocessor import MacenkoNormalizer, preprocessor
from torch.utils.data import Dataset


class WSIPatchDataset(Dataset):
    """Dataset yielding transformed square patches read from one slide.

    Each item is read from the slide with ``OpenSlide.read_region`` at the
    location stored in ``coords``, converted to RGB and passed through the
    transform pipeline built by ``preprocessor``.

    NOTE(review): the slide handle is opened in the constructor and is not
    closed anywhere in this class -- confirm that callers release
    ``self.wsi`` when they are done with the dataset.
    """

    def __init__(
        self,
        coords,
        wsi_path,
        pretrained=False,
        patch_size=256,
        patch_level=0,
        macenko=True,
        return_coord=False,
    ):
        """Open the slide and build the patch transform pipeline.

        Args:
            coords: Sized, indexable collection of patch locations. Each entry
                is handed unchanged to ``OpenSlide.read_region`` (presumably
                an ``(x, y)`` pair in level-0 coordinates -- confirm against
                the caller).
            wsi_path: Path of the slide file given to ``OpenSlide``.
            pretrained: Forwarded to ``preprocessor``.
            patch_size: Edge length of the square region that is read.
            patch_level: Slide level passed to ``read_region``.
            macenko: When true, a ``MacenkoNormalizer`` loaded from
                ``macenko_target/macenko_param.pt`` (two directories above
                this file) is handed to ``preprocessor``; otherwise ``None``.
            return_coord: When true, items are ``(image, coordinate)`` pairs
                instead of the image alone.
        """
        self.pretrained = pretrained
        self.wsi = OpenSlide(wsi_path)
        self.patch_size = patch_size
        self.patch_level = patch_level
        self.return_coord = return_coord

        stain_normalizer = None
        if macenko:
            # The parameter file sits two directory levels above this module.
            base_dir = os.path.dirname(os.path.dirname(__file__))
            param_file = os.path.join(
                base_dir, "macenko_target", "macenko_param.pt"
            )
            stain_normalizer = MacenkoNormalizer(target_path=param_file)

        self.roi_transforms = preprocessor(
            pretrained=pretrained, normalizer=stain_normalizer
        )
        self.coords = coords
        # Cached once; __len__ reports this value.
        self.length = len(self.coords)

    def __len__(self):
        """Return the number of patch locations captured at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the transformed patch at ``idx``.

        When ``return_coord`` is set, the result is a tuple of the
        transformed patch and the location wrapped in ``torch.tensor``.
        """
        location = self.coords[idx]
        region_size = (self.patch_size, self.patch_size)
        region = self.wsi.read_region(location, self.patch_level, region_size)
        patch = self.roi_transforms(region.convert("RGB"))

        if not self.return_coord:
            return patch
        return patch, torch.tensor(location)