File size: 2,814 Bytes
006869b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from base import BaseDataLoader
from dataset.datasets import PatchedDataset
from torchvision import transforms
from torch.utils.data.sampler import SequentialSampler


class PatchedDataLoader(BaseDataLoader):
    """Data loader that serves the patches of a ``PatchedDataset``.

    The samplers handed to the base loader index into
    ``self.dataset.patches`` (one entry per patch); ``_split_sampler``
    translates the split produced by ``BaseDataLoader`` into those
    patch positions.
    """

    def __init__(
        self,
        data_dir,
        patch_size,
        batch_size,
        patch_stride=None,
        preds=None,
        target_dist=0.0,
        shuffle=True,
        validation_split=0.0,
        num_workers=1,
        training=True
    ):
        """Build the transforms and the dataset, then initialise the base loader.

        Args:
            data_dir: stored on the loader and passed to ``PatchedDataset``.
            patch_size: forwarded to ``PatchedDataset``.
            batch_size: forwarded to ``BaseDataLoader``.
            patch_stride: forwarded to ``PatchedDataset``.
            preds: forwarded to ``PatchedDataset``; can be replaced later
                through ``update_dataset``.
            target_dist: forwarded to ``PatchedDataset``.
            shuffle: forwarded to ``BaseDataLoader``; also gates the random
                flip augmentation together with ``training``.
            validation_split: forwarded to ``BaseDataLoader``.
            num_workers: forwarded to ``BaseDataLoader``.
            training: passed to ``PatchedDataset`` as ``train``.
        """
        # Input images: tensor conversion followed by per-channel
        # normalisation with fixed mean / std constants.
        trsfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.3551, 0.4698, 0.2261),
                                 (0.1966, 0.1988, 0.1761))
        ])
        # Targets are only converted to tensors, never normalised.
        target_trsfm = transforms.Compose([
            transforms.ToTensor(),
        ])
        # Random vertical / horizontal flips, applied as a group.
        rand_trsfm = transforms.RandomApply([
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip()
        ])
        self.data_dir = data_dir
        # NOTE(review): late_init=True presumably defers patch creation until
        # make_dataset() is called from _split_sampler -- confirm against
        # PatchedDataset.
        self.dataset = PatchedDataset(
            self.data_dir,
            patch_size,
            patch_stride=patch_stride,
            preds=preds,
            target_dist=target_dist,
            transform=trsfm,
            target_transform=target_trsfm,
            # augmentation is only enabled for shuffled training data
            rand_transform=rand_trsfm if training and shuffle else None,
            train=training,
            late_init=True
        )
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)

    def _split_sampler(self, split):
        """Build the dataset patches and return samplers over patch positions.

        The base implementation produces the initial split. The dataset is
        then (re)built via ``make_dataset`` and every patch is assigned to
        the validation set when its ``idx`` is among the validation indices
        of that split, and to the training set otherwise.

        Args:
            split: validation split, passed unchanged to
                ``BaseDataLoader._split_sampler``.

        Returns:
            ``(train_sampler, valid_sampler)``. When the base class returns
            no validation sampler, ``train_sampler`` is a
            ``SequentialSampler`` over all patch positions and
            ``valid_sampler`` is ``None``.
        """
        train_sampler, valid_sampler = super()._split_sampler(split)

        has_validation = valid_sampler is not None
        if has_validation:
            self.dataset.make_dataset(valid_indices=valid_sampler.indices)
        else:
            self.dataset.make_dataset()

        train_idx, valid_idx = [], []
        # enumerate() yields each patch's own position directly; the previous
        # patches.index(patch) lookup rescanned the list for every patch and
        # would return the first match for patches that compare equal.
        for position, patch in enumerate(self.dataset.patches):
            if has_validation and patch.idx in valid_sampler.indices:
                valid_idx.append(position)
            else:
                train_idx.append(position)

        if has_validation:
            train_sampler.indices, valid_sampler.indices = train_idx, valid_idx
        else:
            train_sampler = SequentialSampler(train_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def update_dataset(self, preds):
        """Replace the dataset predictions and rebuild patches and samplers.

        The existing sampler objects are updated in place rather than
        replaced.

        Args:
            preds: new value assigned to ``self.dataset.preds``.
        """
        self.dataset.preds = preds
        self.dataset.patches.clear()
        # NOTE(review): n_samples is reset from len(self.dataset) right after
        # the patches were cleared; presumably the base _split_sampler reads
        # it and the dataset length does not depend on `patches` -- confirm.
        self.n_samples = len(self.dataset)

        train_sampler, valid_sampler = self._split_sampler(
            self.validation_split)
        if valid_sampler is not None:
            self.valid_sampler.indices = valid_sampler.indices
            self.sampler.indices = train_sampler.indices
        else:
            # Without a validation split _split_sampler returns a
            # SequentialSampler, which exposes `data_source` and has no
            # `indices` attribute; reading `.indices` here used to raise
            # AttributeError.
            # NOTE(review): assumes self.sampler is the SequentialSampler
            # returned by the initial _split_sampler call -- confirm in
            # BaseDataLoader.
            self.sampler.data_source = train_sampler.data_source