|
|
import os |
|
|
from PIL import Image, ImageOps, ImageFilter |
|
|
import torch |
|
|
import random |
|
|
import numpy as np |
|
|
from torch.utils import data |
|
|
from torchvision import transforms |
|
|
from torchvision.transforms import functional as F |
|
|
import numbers |
|
|
import random |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
class CoData(data.Dataset): |
|
|
def __init__(self, img_root, gt_root, img_size, transform, max_num, is_train): |
|
|
|
|
|
class_list = os.listdir(img_root) |
|
|
self.size = [img_size, img_size] |
|
|
self.img_dirs = list( |
|
|
map(lambda x: os.path.join(img_root, x), class_list)) |
|
|
self.gt_dirs = list( |
|
|
map(lambda x: os.path.join(gt_root, x), class_list)) |
|
|
self.transform = transform |
|
|
self.max_num = max_num |
|
|
self.is_train = is_train |
|
|
|
|
|
def __getitem__(self, item): |
|
|
names = os.listdir(self.img_dirs[item]) |
|
|
num = len(names) |
|
|
img_paths = list( |
|
|
map(lambda x: os.path.join(self.img_dirs[item], x), names)) |
|
|
gt_paths = list( |
|
|
map(lambda x: os.path.join(self.gt_dirs[item], x[:-4]+'.png'), names)) |
|
|
|
|
|
if self.is_train: |
|
|
final_num = min(num, self.max_num) |
|
|
|
|
|
sampled_list = random.sample(range(num), final_num) |
|
|
|
|
|
new_img_paths = [img_paths[i] for i in sampled_list] |
|
|
img_paths = new_img_paths |
|
|
new_gt_paths = [gt_paths[i] for i in sampled_list] |
|
|
gt_paths = new_gt_paths |
|
|
|
|
|
final_num = final_num |
|
|
else: |
|
|
final_num = num |
|
|
|
|
|
imgs = torch.Tensor(final_num, 3, self.size[0], self.size[1]) |
|
|
gts = torch.Tensor(final_num, 1, self.size[0], self.size[1]) |
|
|
|
|
|
subpaths = [] |
|
|
ori_sizes = [] |
|
|
for idx in range(final_num): |
|
|
|
|
|
img = Image.open(img_paths[idx]).convert('RGB') |
|
|
gt = Image.open(gt_paths[idx]).convert('L') |
|
|
|
|
|
subpaths.append(os.path.join(img_paths[idx].split('/')[-2], img_paths[idx].split('/')[-1][:-4]+'.png')) |
|
|
ori_sizes.append((img.size[1], img.size[0])) |
|
|
|
|
|
|
|
|
[img, gt] = self.transform(img, gt) |
|
|
|
|
|
imgs[idx] = img |
|
|
gts[idx] = gt |
|
|
if self.is_train: |
|
|
cls_ls = [item] * int(final_num) |
|
|
return imgs, gts, subpaths, ori_sizes, cls_ls |
|
|
else: |
|
|
return imgs, gts, subpaths, ori_sizes |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.img_dirs) |
|
|
|
|
|
|
|
|
class FixedResize(object): |
|
|
def __init__(self, size): |
|
|
self.size = (size, size) |
|
|
|
|
|
def __call__(self, img, gt): |
|
|
|
|
|
|
|
|
img = img.resize(self.size, Image.BILINEAR) |
|
|
gt = gt.resize(self.size, Image.NEAREST) |
|
|
|
|
|
|
|
|
return img, gt |
|
|
|
|
|
|
|
|
class ToTensor(object): |
|
|
def __call__(self, img, gt): |
|
|
|
|
|
return F.to_tensor(img), F.to_tensor(gt) |
|
|
|
|
|
|
|
|
class Normalize(object): |
|
|
"""Normalize a tensor image with mean and standard deviation. |
|
|
Args: |
|
|
mean (tuple): means for each channel. |
|
|
std (tuple): standard deviations for each channel. |
|
|
""" |
|
|
|
|
|
def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): |
|
|
self.mean = mean |
|
|
self.std = std |
|
|
|
|
|
def __call__(self, img, gt): |
|
|
img = F.normalize(img, self.mean, self.std) |
|
|
|
|
|
return img, gt |
|
|
|
|
|
|
|
|
class RandomHorizontalFlip(object): |
|
|
def __init__(self, p=0.5): |
|
|
self.p = p |
|
|
|
|
|
def __call__(self, img, gt): |
|
|
if random.random() < self.p: |
|
|
img = img.transpose(Image.FLIP_LEFT_RIGHT) |
|
|
gt = gt.transpose(Image.FLIP_LEFT_RIGHT) |
|
|
|
|
|
return img, gt |
|
|
|
|
|
|
|
|
class RandomScaleCrop(object): |
|
|
def __init__(self, base_size, crop_size, fill=0): |
|
|
self.base_size = base_size |
|
|
self.crop_size = crop_size |
|
|
self.fill = fill |
|
|
|
|
|
def __call__(self, img, mask): |
|
|
|
|
|
|
|
|
|
|
|
short_size = random.randint(int(self.base_size * 0.8), int(self.base_size * 1.2)) |
|
|
w, h = img.size |
|
|
if h > w: |
|
|
ow = short_size |
|
|
oh = int(1.0 * h * ow / w) |
|
|
else: |
|
|
oh = short_size |
|
|
ow = int(1.0 * w * oh / h) |
|
|
img = img.resize((ow, oh), Image.BILINEAR) |
|
|
mask = mask.resize((ow, oh), Image.NEAREST) |
|
|
|
|
|
if short_size < self.crop_size: |
|
|
padh = self.crop_size - oh if oh < self.crop_size else 0 |
|
|
padw = self.crop_size - ow if ow < self.crop_size else 0 |
|
|
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) |
|
|
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) |
|
|
|
|
|
w, h = img.size |
|
|
x1 = random.randint(0, w - self.crop_size) |
|
|
y1 = random.randint(0, h - self.crop_size) |
|
|
img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) |
|
|
mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) |
|
|
|
|
|
return img, mask |
|
|
|
|
|
|
|
|
class RandomRotation(object): |
|
|
def __init__(self, degrees, resample=False, expand=False, center=None): |
|
|
if isinstance(degrees, numbers.Number): |
|
|
if degrees < 0: |
|
|
raise ValueError("If degrees is a single number, it must be positive.") |
|
|
self.degrees = (-degrees, degrees) |
|
|
else: |
|
|
if len(degrees) != 2: |
|
|
raise ValueError("If degrees is a sequence, it must be of len 2.") |
|
|
self.degrees = degrees |
|
|
|
|
|
self.resample = resample |
|
|
self.expand = expand |
|
|
self.center = center |
|
|
|
|
|
@staticmethod |
|
|
def get_params(degrees): |
|
|
angle = random.uniform(degrees[0], degrees[1]) |
|
|
|
|
|
return angle |
|
|
|
|
|
def __call__(self, img, gt): |
|
|
""" |
|
|
img (PIL Image): Image to be rotated. |
|
|
|
|
|
Returns: |
|
|
PIL Image: Rotated image. |
|
|
""" |
|
|
|
|
|
angle = self.get_params(self.degrees) |
|
|
|
|
|
return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), F.rotate(gt, angle, Image.NEAREST, self.expand, self.center) |
|
|
|
|
|
|
|
|
|
|
|
class Compose(object): |
|
|
def __init__(self, transforms): |
|
|
self.transforms = transforms |
|
|
|
|
|
def __call__(self, img, gt): |
|
|
for t in self.transforms: |
|
|
img, gt = t(img, gt) |
|
|
return img, gt |
|
|
|
|
|
def __repr__(self): |
|
|
format_string = self.__class__.__name__ + '(' |
|
|
for t in self.transforms: |
|
|
format_string += '\n' |
|
|
format_string += ' {0}'.format(t) |
|
|
format_string += '\n)' |
|
|
return format_string |
|
|
|
|
|
|
|
|
|
|
|
def get_loader(img_root, gt_root, img_size, batch_size, max_num = float('inf'), istrain=True, shuffle=False, num_workers=0, pin=False): |
|
|
if istrain: |
|
|
transform = Compose([ |
|
|
RandomScaleCrop(img_size*2, img_size*2), |
|
|
FixedResize(img_size), |
|
|
RandomHorizontalFlip(), |
|
|
|
|
|
RandomRotation((-90, 90)), |
|
|
ToTensor(), |
|
|
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
else: |
|
|
transform = Compose([ |
|
|
FixedResize(img_size), |
|
|
|
|
|
ToTensor(), |
|
|
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
dataset = CoData(img_root, gt_root, img_size, transform, max_num, is_train=istrain) |
|
|
data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, |
|
|
pin_memory=pin) |
|
|
return data_loader |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
mean = [0.485, 0.456, 0.406] |
|
|
std = [0.229, 0.224, 0.225] |
|
|
img_root = './data/testtrain/img/' |
|
|
gt_root = './data/testtrain/gt/' |
|
|
loader = get_loader(img_root, gt_root, 20, 1, 16, istrain=False) |
|
|
for batch in loader: |
|
|
b, c, h, w = batch[0][0].shape |
|
|
for i in range(b): |
|
|
img = batch[0].squeeze(0)[i].permute(1, 2, 0).cpu().numpy() * std + mean |
|
|
image = img * 255 |
|
|
mask = batch[1].squeeze(0)[i].squeeze().cpu().numpy() |
|
|
plt.subplot(121) |
|
|
plt.imshow(np.uint8(image)) |
|
|
plt.subplot(122) |
|
|
plt.imshow(mask) |
|
|
plt.show(block=True) |
|
|
|