| | import random |
| | import os |
| | import math |
| | import mmcv |
| | import torch |
| | import numpy as np |
| | import torchvision.transforms as T |
| | from torchvision import datasets |
| | from torch.utils.data import Dataset |
| | from megatron.data.autoaugment import ImageNetPolicy |
| | from tasks.vision.segmentation.cityscapes import Cityscapes |
| | import tasks.vision.segmentation.transforms as ET |
| | from megatron.data.autoaugment import ImageNetPolicy |
| | from megatron import get_args |
| | from PIL import Image, ImageOps |
| |
|
| |
|
class VitSegmentationJointTransform():
    """Joint image/mask augmentation for segmentation training.

    Applies a random resize+crop followed by a random horizontal flip to
    the (image, mask) pair together, so pixel-wise correspondence between
    image and label is preserved. In eval mode (train=False) the pair is
    returned untouched.
    """

    def __init__(self, train=True, resolution=None):
        self.train = train
        if train:
            # Spatial augmentations exist only in training mode.
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        # Eval mode: identity on both image and mask.
        if not self.train:
            return img, mask
        img, mask = self.transform0(img, mask)
        return self.transform1(img, mask)
| |
|
| |
|
class VitSegmentationImageTransform():
    """Image-only transform pipeline for ViT segmentation.

    Converts a PIL image to a tensor, normalizes with the dataset
    mean/std published on args, and casts to the reduced-precision dtype
    selected by --fp16/--bf16. Training mode prepends a photometric
    distortion augmentation.
    """

    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        # The pipeline emits reduced-precision tensors, so one of the
        # two half-precision modes must be enabled.
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std
        steps = [
            T.ToTensor(),
            T.Normalize(*self.mean_std),
            T.ConvertImageDtype(self.data_type),
        ]
        if train:
            assert resolution is not None
            # Color-jitter style augmentation is train-only.
            steps.insert(0, ET.PhotoMetricDistortion())
        self.transform = T.Compose(steps)

    def __call__(self, input):
        return self.transform(input)
| |
|
| |
|
class VitSegmentationTargetTransform():
    """Convert a segmentation mask (PIL image / array-like) to a long tensor.

    `train` and `resolution` are accepted only for interface symmetry with
    the image transform; they do not affect the conversion.
    """

    def __init__(self, train=True, resolution=None):
        self.train = train

    def __call__(self, input):
        # Materialize the class ids as int64 up front instead of the
        # original int32-array + .long() round trip, which allocated an
        # intermediate tensor just to change dtype. torch.long is int64,
        # so the result is identical.
        return torch.from_numpy(np.array(input, dtype=np.int64))
| |
|
| |
|
class RandomSeedSegmentationDataset(Dataset):
    """Dataset wrapper that reseeds all RNGs per sample.

    Every __getitem__ derives a deterministic seed from the sample index
    plus the current epoch seed and pushes it into torch, random, and
    numpy before running the transforms, so random augmentations are a
    pure function of (index, epoch, base seed) — reproducible within an
    epoch yet varying across epochs via set_epoch().
    """

    def __init__(self, dataset, joint_transform, image_transform,
                 target_transform):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        # Shift the per-sample seed space every epoch so augmentations
        # differ between epochs while staying reproducible.
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        img, mask = self.dataset[idx]

        # Seed every RNG the transforms might draw from.
        sample_seed = self.curr_seed + idx
        torch.manual_seed(sample_seed)
        random.seed(sample_seed)
        np.random.seed(sample_seed)

        img, mask = self.joint_transform(img, mask)
        return self.image_transform(img), self.target_transform(mask)
| |
|
| |
|
def build_cityscapes_train_valid_datasets(data_path, image_size):
    """Build (train, valid) Cityscapes segmentation datasets.

    Publishes dataset constants (num_classes, ignore_index, color_table,
    normalization mean/std) on the global args, then wraps each fine-mode
    split in RandomSeedSegmentationDataset so per-sample augmentation is
    reproducible.

    Args:
        data_path: sequence whose first element is the Cityscapes root.
        image_size: target resolution passed to dataset and transforms.

    Returns:
        (train_dataset, valid_dataset) tuple.
    """
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    # ImageNet normalization statistics.
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def _make_split(split, train):
        # Build one split plus its (train/eval-specific) transform stack.
        raw = Cityscapes(
            root=data_path[0],
            split=split,
            mode='fine',
            resolution=image_size
        )
        return RandomSeedSegmentationDataset(
            raw,
            joint_transform=VitSegmentationJointTransform(
                train=train, resolution=image_size),
            image_transform=VitSegmentationImageTransform(
                train=train, resolution=image_size),
            target_transform=VitSegmentationTargetTransform(
                train=train, resolution=image_size))

    return _make_split('train', True), _make_split('val', False)
| |
|
| |
|
def build_train_valid_datasets(data_path, image_size):
    """Entry point: dispatch to the Cityscapes builder (only supported set)."""
    return build_cityscapes_train_valid_datasets(data_path, image_size)
| |
|