| import os | |
| import torch | |
| from openslide import OpenSlide | |
| from utils.preprocessor import MacenkoNormalizer, preprocessor | |
| from torch.utils.data import Dataset | |
class WSIPatchDataset(Dataset):
    """Map-style dataset that reads square patches from one whole-slide image.

    Each item is the region of the slide located at ``coords[idx]``, read at
    ``patch_level`` with size ``(patch_size, patch_size)``, converted to RGB
    and passed through the transform built by ``preprocessor``.

    The dataset owns an open ``OpenSlide`` handle (``self.wsi``). Call
    :meth:`close` when finished, or use the dataset as a context manager, to
    release it.

    NOTE(review): the handle is opened eagerly in ``__init__``; confirm this
    is compatible with the ``DataLoader`` worker start method in use.
    """

    def __init__(
        self,
        coords,
        wsi_path,
        pretrained=False,
        patch_size=256,
        patch_level=0,
        macenko=True,
        return_coord=False,
    ):
        """Build the patch transform and open the slide.

        Args:
            coords: Indexable, sized collection of patch locations. Each
                entry is handed to ``OpenSlide.read_region`` as its
                ``location`` argument and, when ``return_coord`` is true, to
                ``torch.tensor``.
            wsi_path: Path of the slide file passed to ``OpenSlide``.
            pretrained: Forwarded unchanged to ``preprocessor``.
            patch_size: Side length of the square region requested from
                ``read_region``.
            patch_level: Slide level passed to ``read_region``.
            macenko: When true, a ``MacenkoNormalizer`` loaded from the
                bundled ``macenko_target/macenko_param.pt`` file is given to
                ``preprocessor``; otherwise ``normalizer=None`` is used.
            return_coord: When true, ``__getitem__`` also returns the patch
                coordinate as a tensor.
        """
        super().__init__()
        self.pretrained = pretrained
        self.patch_size = patch_size
        self.patch_level = patch_level
        self.return_coord = return_coord
        self.coords = coords
        self.length = len(self.coords)

        if macenko:
            normalizer = MacenkoNormalizer(
                target_path=self._default_macenko_target()
            )
        else:
            normalizer = None
        self.roi_transforms = preprocessor(
            pretrained=pretrained, normalizer=normalizer
        )

        # Open the slide last: if anything above raises (bad coords, missing
        # normalizer parameters, ...) no native handle has been acquired yet,
        # so nothing is leaked.
        self.wsi = OpenSlide(wsi_path)

    @staticmethod
    def _default_macenko_target():
        """Return the absolute path of the bundled Macenko parameter file.

        The file is looked up in ``macenko_target/macenko_param.pt`` inside
        the parent of the directory that contains this module.
        """
        # abspath makes the result independent of the current working
        # directory even when ``__file__`` is a relative path.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(
            os.path.dirname(module_dir),
            "macenko_target",
            "macenko_param.pt",
        )

    def __len__(self):
        """Return the number of coordinates recorded at construction time."""
        return self.length

    def __getitem__(self, idx):
        """Read, convert and transform the patch at ``coords[idx]``.

        Returns:
            The output of ``self.roi_transforms`` for the patch, or a tuple
            ``(patch, torch.tensor(coord))`` when ``return_coord`` is true.
        """
        coord = self.coords[idx]
        region = self.wsi.read_region(
            coord, self.patch_level, (self.patch_size, self.patch_size)
        )
        img = self.roi_transforms(region.convert("RGB"))
        if self.return_coord:
            return img, torch.tensor(coord)
        return img

    def close(self):
        """Close the underlying slide handle held in ``self.wsi``."""
        self.wsi.close()

    def __enter__(self):
        """Return the dataset itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the slide on leaving the ``with`` block; never suppresses errors."""
        self.close()