# File size: 5,261 Bytes
# 891e05c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from functools import partial
import os
import numpy as np
from omegaconf import DictConfig
import pytorch_lightning as pl
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader

from datasets import load_dataset, concatenate_datasets
from models.loupe import LoupeImageProcessor, LoupeConfig


class DataModule(pl.LightningDataModule):
    """Lightning data module serving the parquet image/mask splits.

    The splits are loaded with ``load_dataset("parquet", ...)`` from
    ``cfg.dataset.data_dir``. Batches are turned into model inputs by a
    ``LoupeImageProcessor`` built from ``model_config``.

    Config values read by this class:
        cfg.dataset.data_dir, cfg.dataset.valid_size, cfg.dataset.num_workers,
        cfg.hparams.batch_size, cfg.seed, cfg.stage.name and the optional
        flags cfg.stage.train_on_trainset / cfg.stage.enable_tta.
    """

    def __init__(self, cfg: DictConfig, model_config: LoupeConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.model_config = model_config
        self.processor = LoupeImageProcessor(self.model_config)

    def _resolve_valid_size(self, num_samples: int) -> int:
        """Translate ``cfg.dataset.valid_size`` into an absolute sample count.

        Args:
            num_samples: Size of the full validation split.

        Returns:
            The number of samples to keep for validation; an int setting is
            used as-is, a float setting is a fraction of ``num_samples``.

        Raises:
            ValueError: If the setting is neither int nor float, if a float is
                outside (0, 1], or if the resolved count is not strictly
                between 0 and ``num_samples``.
        """
        requested = self.cfg.dataset.valid_size
        if isinstance(requested, int):
            valid_size = requested
        elif isinstance(requested, float):
            if not 0 < requested <= 1:
                raise ValueError(
                    f"Invalid valid_size: {requested}. A float valid_size must be in (0, 1]."
                )
            valid_size = int(requested * num_samples)
        else:
            raise ValueError(
                f"Invalid valid_size: {self.cfg.dataset.valid_size}. It should be either int or float."
            )

        # Checked explicitly (not with ``assert``, which ``python -O`` strips)
        # and on the resolved count, so a fraction that rounds down to 0 or
        # covers the whole split is reported here rather than inside the split.
        if not 0 < valid_size < num_samples:
            raise ValueError(
                f"Invalid valid_size: {requested}. It resolves to {valid_size} samples, "
                f"but must be at least 1 and fewer than {num_samples}."
            )
        return valid_size

    def setup(self, stage: str) -> None:
        """Load the dataset and select the splits needed for ``stage``.

        * ``None`` / ``"fit"`` / ``"validate"``: the "validation" split is
          shuffled and divided; ``valid_size`` samples become ``self.validset``.
          ``self.trainset`` is the remainder of that split when
          ``cfg.stage.name`` is "cls_seg" or "test" and
          ``cfg.stage.train_on_trainset`` is unset/false, otherwise it is the
          "train" split.
        * ``"test"``: ``self.testset`` is the full "validation" split.
        * ``"predict"``: ``self.testset`` is the "test" split.

        Any other stage value leaves the module unchanged.

        Raises:
            ValueError: If ``cfg.dataset.valid_size`` is invalid.
        """
        dataset = load_dataset("parquet", data_dir=self.cfg.dataset.data_dir)

        if stage in (None, "validate", "fit"):
            full_validset = dataset["validation"]
            valid_size = self._resolve_valid_size(len(full_validset))

            # use a small subset to prevent too long validation time
            split = full_validset.train_test_split(
                test_size=valid_size, seed=self.cfg.seed, shuffle=True
            )
            additional_trainset = split["train"]
            self.validset = split["test"]

            train_on_leftover = self.cfg.stage.name in (
                "cls_seg",
                "test",
            ) and not getattr(self.cfg.stage, "train_on_trainset", False)
            # Conditional expression keeps the "train" lookup lazy: it is only
            # touched when that split is actually used.
            self.trainset = (
                additional_trainset if train_on_leftover else dataset["train"]
            )
        elif stage == "test":
            self.testset = dataset["validation"]
        elif stage == "predict":
            self.testset = dataset["test"]

    @staticmethod
    def _split_batch(batch):
        """Return ``(images, masks, labels)`` for a list of sample dicts.

        ``labels`` is a long tensor of shape (N,): 1 where the sample has a
        mask, 0 where the mask is ``None`` (a ``None`` mask means the image
        is real).
        """
        images = [sample["image"] for sample in batch]
        masks = [sample["mask"] for sample in batch]
        labels = torch.tensor(
            [mask is not None for mask in masks], dtype=torch.long
        )
        return images, masks, labels

    def _make_loader(self, dataset, collate_fn, shuffle: bool = False) -> DataLoader:
        """Build a ``DataLoader`` with the batch size / worker count from cfg."""
        return DataLoader(
            dataset,
            batch_size=self.cfg.hparams.batch_size,
            num_workers=self.cfg.dataset.num_workers,
            collate_fn=collate_fn,
            shuffle=shuffle,
        )

    def train_collate_fn(self, batch):
        """
        Collate function for the train dataloader.
        Args:
            batch: List of dictionaries containing "image" and "mask" keys.
        Returns:
            The processor outputs plus "labels", a long tensor of shape (N,).
        """
        images, masks, labels = self._split_batch(batch)

        return {
            **self.processor(
                images,
                # masks are withheld from the processor when cfg.stage.enable_tta is set
                masks if not getattr(self.cfg.stage, "enable_tta", False) else None,
                self.model_config.enable_patch_cls,
                return_tensors="pt",
            ),
            "labels": labels,  # (N,)
        }

    def train_dataloader(self):
        """Return the shuffled loader over ``self.trainset``."""
        return self._make_loader(self.trainset, self.train_collate_fn, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over ``self.validset``."""
        return self._make_loader(self.validset, self.test_collate_fn)

    def test_collate_fn(self, batch):
        """
        Collate function for valid and test dataloaders.
        Args:
            batch: List of dictionaries containing "image" and "mask" keys.
        Returns:
            The processor outputs plus "masks" (a list with one mask per
            sample, at the original image size) and "labels", a long tensor
            of shape (N,).
        """
        images, masks, labels = self._split_batch(batch)

        outputs = self.processor(images, masks, return_tensors="pt")

        # Built as a new list so the raw masks handed to the processor above
        # are never overwritten.
        target_masks = []
        for image, mask in zip(images, masks):
            if mask is None:
                # note that in PIL image, the size is (W, H)
                width, height = image.size[0], image.size[1]
                target_masks.append(
                    torch.zeros((height, width), dtype=torch.uint8)
                )
            else:
                # convert to binary mask with 0 and 1
                target_masks.append(self.processor.convert_to_binary_masks(mask))

        return {
            **outputs,
            "masks": target_masks,  # a list of (N, H_i, W_i)
            "labels": labels,  # (N,)
        }

    def test_dataloader(self):
        """Return the unshuffled loader over ``self.testset``."""
        return self._make_loader(self.testset, self.test_collate_fn)

    def predict_collate_fn(self, batch):
        """
        Collate function for predict dataloader.
        Args:
            batch: List of dictionaries containing "image" and "path" keys.
        Returns:
            The processor outputs plus "target_sizes" (each image's ``size``
            reversed, i.e. (H, W) for a PIL image) and "name" (the base name
            of each sample's path).
        """
        images = [sample["image"] for sample in batch]

        outputs = self.processor(images, return_tensors="pt")

        return {
            **outputs,
            "target_sizes": [image.size[::-1] for image in images],
            "name": [os.path.basename(sample["path"]) for sample in batch],
        }

    def predict_dataloader(self):
        """Return the unshuffled loader over ``self.testset`` for prediction."""
        return self._make_loader(self.testset, self.predict_collate_fn)