| from typing import Any, Callable, Optional | |
| import torch.nn.functional as F | |
| from base import BaseDataset | |
| from utils.util import TransformMultiple, pil_loader | |
| from dataset.patches import Patches | |
class PatchedDataset(BaseDataset):
    """Dataset that serves patches cut from the (image, mask) samples of ``BaseDataset``.

    Each sample of the parent dataset is split into patches by a ``Patches``
    helper, and ``__getitem__`` indexes those patches rather than whole images.

    NOTE(review): ``__len__`` is not overridden here, so ``len(self)`` is whatever
    ``BaseDataset.__len__`` returns. ``make_dataset`` uses that same value as the
    number of images. Confirm that ``len()`` on this dataset reports the patch
    count; otherwise a DataLoader would only visit the first patches.
    """

    def __init__(
        self,
        root: str,
        patch_size: int,
        patch_stride: Optional[int] = None,
        preds: Optional[list] = None,
        target_dist: float = 0.0,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        rand_transform: Optional[Callable] = None,
        train: bool = True,
        late_init: bool = False,
    ) -> None:
        """Build the dataset and, unless ``late_init`` is set, its patches.

        Args:
            root: Passed to ``BaseDataset`` together with ``pil_loader``.
            patch_size: Patch side length, forwarded to ``Patches``. It is also
                squared to turn ``target_dist`` into a pixel count.
            patch_stride: Forwarded to ``Patches``. The meaning of ``None`` is
                decided by ``Patches``.
            preds: Optional per-sample predictions, indexed by sample index.
                Each one is unioned with that sample's mask before patches
                are created.
            target_dist: Fraction of the patch area. When non-zero, it is the
                threshold used by ``_dist_fn`` to filter patches during
                creation.
            transforms, transform, target_transform, train: Forwarded
                unchanged to ``BaseDataset``.
            rand_transform: Wrapped in ``TransformMultiple`` and applied to
                every (image patch, mask patch) pair in ``__getitem__``.
            late_init: When True, ``make_dataset`` is not called here and the
                caller must invoke it explicitly.
        """
        super().__init__(root, pil_loader, transforms, transform, target_transform, train)
        self.patches = Patches(patch_size, patch_stride)
        self.preds = preds
        # Convert the area fraction once into an absolute count of non-zero
        # mask pixels; _dist_fn compares against this value.
        self.target_dist = target_dist * patch_size ** 2
        self.rand_transform = TransformMultiple(rand_transform)
        if not late_init:
            self.make_dataset()

    def make_dataset(self, valid_indices: Optional[list] = None) -> None:
        """Create the patches for every sample of the parent dataset.

        Samples listed in ``valid_indices`` are patched with
        ``no_overlap=True`` and without the ``target_dist`` condition. All
        other samples use ``_dist_fn`` as ``cond_fn`` when ``target_dist`` is
        non-zero, and no condition otherwise.

        Args:
            valid_indices: Parent-dataset sample indices that get
                non-overlapping patches. Any iterable of integer-like values
                is accepted. ``None`` (the default) means there are no such
                samples.
        """
        # Normalise to a set of plain ints. This gives O(1) membership per
        # sample and avoids a shared mutable default. int() makes tensor or
        # ndarray elements hash and compare like Python ints.
        no_overlap_idx = set() if valid_indices is None else {int(i) for i in valid_indices}
        # target_dist does not change inside the loop, so pick the condition once.
        cond_fn = self._dist_fn if self.target_dist != 0.0 else None
        for idx in range(super().__len__()):
            _, mask = super().__getitem__(idx)
            if self.preds is not None:
                mask = self._union_mask(mask, self.preds[idx])
            if idx in no_overlap_idx:
                self.patches.create(idx, mask, no_overlap=True)
            else:
                self.patches.create(idx, mask, cond_fn=cond_fn)

    def __getitem__(self, index: int) -> Any:
        """Return the ``(image_patch, mask_patch)`` pair for patch ``index``.

        The full sample is reloaded from the parent dataset and both image and
        mask are cropped to the patch. ``rand_transform`` is then applied to
        the (image, mask) tuple. The returned mask comes from the parent
        dataset, not the preds union used during patch creation.
        """
        patch = self.patches[index]
        img, mask = super().__getitem__(patch.idx)
        img_patch = self.patches.get_patch(img, patch)
        mask_patch = self.patches.get_patch(mask, patch)
        # The mask gets a temporary leading dimension while rand_transform
        # runs, and loses it again afterwards. Presumably this shapes it like
        # a 1-channel image for the transform — confirm against TransformMultiple.
        img_patch, mask_patch = self.rand_transform(
            (img_patch, mask_patch.unsqueeze(dim=0)))
        return img_patch, mask_patch.squeeze(dim=0)

    def _union_mask(self, mask, pred):
        """Return the union of ``mask`` and ``pred`` after fitting ``pred`` to ``mask``.

        ``pred`` is padded with zeros on the right and bottom to
        ``mask.shape[:2]``. Negative pad amounts make ``F.pad`` crop instead.
        The result ``a + b - a*b`` equals logical OR for 0/1 masks. For other
        values in [0, 1] it is the probabilistic OR.
        NOTE(review): assumes 2-D (H, W) tensors, as implied by the indexing
        of shape[0] and shape[1] — confirm.
        """
        pred = F.pad(
            pred, (0, mask.shape[1] - pred.shape[1], 0, mask.shape[0] - pred.shape[0]))
        return (mask + pred) - (mask * pred)

    def _dist_fn(self, mask, patch):
        """Patch condition: the patch contains more than ``target_dist`` non-zero mask pixels."""
        data = self.patches.get_patch(mask, patch)
        return data.count_nonzero() > self.target_dist