|
|
import pytorch_lightning as pl |
|
|
from torchvision.datasets import VOCSegmentation |
|
|
from torchvision.transforms import transforms |
|
|
from torch.utils.data import DataLoader, random_split |
|
|
|
|
|
class SegmentationDataModule(pl.LightningDataModule):
    """LightningDataModule for Pascal VOC 2012 semantic segmentation.

    Wraps ``torchvision.datasets.VOCSegmentation`` and performs an 80/20
    train/validation split of the ``trainval`` image set.

    Args:
        data_dir: Root directory that contains (or will contain) the
            VOC ``VOCdevkit`` folder.
        config: Object exposing ``batch_size``, ``shuffle`` and
            ``num_workers`` attributes, used to build the dataloaders.
    """

    def __init__(self, data_dir: str, config):
        super().__init__()
        self.config = config

        # Images: resize, then convert to float tensors scaled to [0, 1].
        image_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])

        # Masks carry integer class ids, so they must NOT go through
        # ToTensor (which rescales values into [0, 1] floats — the
        # following .long() would collapse nearly all labels to 0), and
        # they must be resized with NEAREST interpolation (the default
        # bilinear would blend neighboring class ids into meaningless
        # intermediate values).
        mask_transform = transforms.Compose([
            transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.PILToTensor(),        # value-preserving uint8 tensor, shape (1, H, W)
            lambda m: m.squeeze(0).long(),   # (H, W) int64 class-id map
        ])

        self.transform = DualTransform(image_transform, mask_transform)
        self.data_dir = data_dir

    def prepare_data(self):
        """Verify the VOC 2012 dataset is present on disk.

        With ``download=False`` this only instantiates the dataset,
        which raises if the data is missing; flip to ``download=True``
        if the data should be fetched automatically.
        """
        VOCSegmentation(root=self.data_dir, year='2012', image_set='trainval', download=False)

    def setup(self, stage=None):
        """Build the full dataset and split it 80/20 into train/val."""
        # `transforms=` (plural) takes one joint callable applied to the
        # (image, target) pair — exactly what DualTransform provides.
        self.dataset = VOCSegmentation(root=self.data_dir, year='2012', image_set='trainval', transforms=self.transform)

        # NOTE(review): random_split uses torch's global RNG here, so the
        # split differs between runs unless a seeded generator is passed.
        train_len = int(0.8 * len(self.dataset))
        val_len = len(self.dataset) - train_len
        self.train_dataset, self.val_dataset = random_split(self.dataset, [train_len, val_len])

    def train_dataloader(self):
        """Training loader; shuffling is controlled by the config."""
        return DataLoader(self.train_dataset, batch_size=self.config.batch_size, shuffle=self.config.shuffle, num_workers=self.config.num_workers)

    def val_dataloader(self):
        """Validation loader; never shuffled."""
        return DataLoader(self.val_dataset, batch_size=self.config.batch_size, shuffle=False, num_workers=self.config.num_workers)
|
|
|
|
|
|
|
|
|
|
|
class DualTransform:
    """Apply two independent transforms to an (image, mask) pair.

    Instances are callables suitable for a torchvision dataset's
    ``transforms=`` argument, which expects a single callable taking
    both the input and the target.
    """

    def __init__(self, image_transform, mask_transform):
        # Keep the two pipelines separate so masks can skip image-only
        # operations (e.g. normalization or value scaling).
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __call__(self, image, mask):
        """Return the transformed ``(image, mask)`` tuple."""
        transformed_image = self.image_transform(image)
        transformed_mask = self.mask_transform(mask)
        return transformed_image, transformed_mask
|
|
|