|
|
from PIL import Image |
|
|
from .base_dataset import BaseDataset |
|
|
import numpy as np |
|
|
class WHUCD(BaseDataset): |
|
|
def __init__(self, data_root='data/WHU_CD', mode='train', transform=None, imgA_dir='image1', imgB_dir='image2', mask_dir='label', img_suffix='.png', mask_suffix='.png', **kwargs): |
|
|
super(WHUCD, self).__init__(transform, mode) |
|
|
|
|
|
self.imgA_dir = imgA_dir |
|
|
self.imgB_dir = imgB_dir |
|
|
self.img_suffix = img_suffix |
|
|
self.mask_dir = mask_dir |
|
|
self.mask_suffix = mask_suffix |
|
|
|
|
|
self.data_root = data_root + "/" + mode |
|
|
self.file_paths = self.get_path(self.data_root, imgA_dir, imgB_dir, mask_dir) |
|
|
|
|
|
|
|
|
self.color_map = { |
|
|
'NotChanged' : np.array([0, 0, 0]), |
|
|
'Changed' : np.array([255, 255, 255]), |
|
|
} |
|
|
|
|
|
self.num_classes = 2 |
|
|
|
|
|
def rgb2label(self,mask_rgb): |
|
|
|
|
|
mask_rgb = np.array(mask_rgb) |
|
|
_mask_rgb = mask_rgb.transpose(2, 0, 1) |
|
|
label_seg = np.zeros(_mask_rgb.shape[1:], dtype=np.uint8) |
|
|
label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['NotChanged'], axis=-1)] = 0 |
|
|
label_seg[np.all(_mask_rgb.transpose([1, 2, 0]) == self.color_map['Changed'], axis=-1)] = 1 |
|
|
|
|
|
_label_seg = Image.fromarray(label_seg).convert('L') |
|
|
return _label_seg |
|
|
|
|
|
|
|
|
|