# Thundernet / data_gen.py
# Uploaded by ExtendedRealityLab via the upload-large-folder tool (commit ae29340, verified).
import bisect
import itertools
import random
from pathlib import Path

import cv2
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow.keras as keras
import torch.nn.functional as F
from torch.utils.data import Dataset
class ImageHelper:
    """Loads one image/label pair from disk and applies configured augmentations.

    Keyword flags: ``to_stereo`` (duplicate side by side), ``flip``
    (horizontal flip of image AND label) and the color transformations listed
    in COLOR_TRANSFORMATIONS (applied to the image only).
    """

    # Color augmentations that change the image but leave the label untouched.
    COLOR_TRANSFORMATIONS = [
        "saturation",
        "contrast",
        "brightness",
    ]

    def __init__(self, img_path, label_path, output_size, **kwargs):
        """
        Args:
            img_path: path to the image file.
            label_path: path to the label file (same shape as the image).
            output_size: (height, width) of the returned arrays.
            **kwargs: boolean augmentation flags (see class docstring).
        """
        self.img_path = img_path
        self.label_path = label_path
        self.output_size = output_size
        self.kwargs = kwargs
        # Stereo: duplicate image/label horizontally in get().
        self.to_stereo = bool(kwargs.get("to_stereo"))
        # Horizontal flip of both image and label, so they stay aligned.
        self.flip = bool(kwargs.get("flip"))
        # Enabled color transformations, in the order they were passed.
        self.color_transformations = [
            k for k, v in self.kwargs.items()
            if k in self.COLOR_TRANSFORMATIONS and v
        ]

    def get(self):
        """Load, augment and resize; return (image HxWx3, label HxW)."""
        # Load
        img = cv2.imread(str(self.img_path))
        label = cv2.imread(str(self.label_path))
        # Size checking
        assert img.shape == label.shape, (
            f"image/label shape mismatch: {img.shape} vs {label.shape}"
        )
        # Flip image and label together so they stay aligned
        if self.flip:
            img, label = self.apply_transformation("flip", img, label)
        # Color transformations (image only)
        for color_tr in self.color_transformations:
            img, label = self.apply_transformation(color_tr, img, label)
        # The TF ops above may have left tensors untouched paths as arrays;
        # guarantee numpy arrays before the cv2 calls below.
        if not isinstance(img, np.ndarray):
            img = np.array(img)
        if not isinstance(label, np.ndarray):
            label = np.array(label)
        # To stereo: concatenate along the width axis.
        if self.to_stereo:
            img = np.concatenate((img, img), axis=1)
            label = np.concatenate((label, label), axis=1)
        # Resize; nearest-neighbor for the label so class ids stay exact.
        img = cv2.resize(img, self.output_size[::-1])
        label = cv2.resize(
            label, self.output_size[::-1], interpolation=cv2.INTER_NEAREST
        )
        # cv2.imread yields 3 channels; only the first is kept — assumes the
        # label encodes the class id identically in all channels (TODO confirm).
        label = label[:, :, 0]
        return img, label

    @classmethod
    def apply_transformation(cls, transformation, img, label):
        """Apply one named transformation and return the new (img, label).

        Raises:
            ValueError: for unsupported or unknown transformation names.
        """
        if transformation == "flip":
            return cls.tensor_to_numpy(
                tf.image.flip_left_right(img)
            ), cls.tensor_to_numpy(tf.image.flip_left_right(label))
        elif transformation == "saturation":
            return cls.tensor_to_numpy(tf.image.random_saturation(img, 0.5, 1.5)), label
        elif transformation == "contrast":
            return cls.tensor_to_numpy(tf.image.random_contrast(img, 0.5, 1.5)), label
        elif transformation == "brightness":
            return cls.tensor_to_numpy(tf.image.random_brightness(img, 0.3)), label
        elif transformation == "rotate":
            raise ValueError("This transformation is not supported yet")
        elif transformation == "directed_crop":
            raise ValueError("This transformation is not supported")
        # FIX: an unknown name previously fell through and returned None,
        # which only crashed later with a confusing unpacking error.
        raise ValueError(f"Unknown transformation: {transformation}")

    @staticmethod
    def tensor_to_numpy(tensor):
        """Convert an eager TF tensor to numpy; graph mode is unsupported."""
        if tf.executing_eagerly():
            return tensor.numpy()
        raise NotImplementedError(
            "Please adapt the Data Generator to work when not executing eagerly"
        )
class DataGenerator(keras.utils.Sequence):
    """Keras Sequence yielding batches of (image, one-hot label) pairs.

    Each image on disk contributes one un-augmented ImageHelper plus one
    extra helper per enabled transformation, so the effective dataset size
    grows with the number of augmentations.
    """

    def __init__(
        self,
        images_path,
        labels_path,
        n_classes,
        batch_size=32,
        output_size=(480, 640),
        to_stereo=False,
        flip=False,
        saturation=False,
        contrast=False,
        brightness=False,
        class_mappings=None,
    ):
        """
        Args:
            images_path: directory with the input images.
            labels_path: directory with one <image_stem>.png label per image.
            n_classes: number of channels in the one-hot encoding.
            batch_size: number of samples per batch.
            output_size: (height, width) of the generated arrays.
            to_stereo: duplicate each image/label side by side.
            flip/saturation/contrast/brightness: augmentations to enable.
            class_mappings: optional dict remapping raw label values to
                one-hot channel indices.

        Raises:
            FileNotFoundError: if any image lacks a matching label file.
        """
        self.images_path = Path(images_path)
        self.labels_path = Path(labels_path)
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.output_size = output_size
        self.to_stereo = to_stereo
        self.class_mappings = class_mappings
        # Check image and labels dir
        img_paths = sorted(self.images_path.iterdir())

        def has_label(img_filename):
            return (self.labels_path / f"{img_filename.stem}.png").exists()

        if not all(map(has_label, img_paths)):
            raise FileNotFoundError("Check every image has a label")
        # Obtain transformations
        transformations = [
            name
            for name, enabled in (
                ("flip", flip),
                ("saturation", saturation),
                ("contrast", contrast),
                ("brightness", brightness),
            )
            if enabled
        ]
        # Prepare augmentation: one plain helper per image plus one helper
        # per enabled transformation.
        elements = []
        for image_path in img_paths:
            label_path = self.labels_path / f"{image_path.stem}.png"
            elements.append(
                ImageHelper(
                    image_path,
                    label_path,
                    self.output_size,
                    to_stereo=self.to_stereo,
                )
            )
            for tr in transformations:
                elements.append(
                    ImageHelper(
                        image_path,
                        label_path,
                        self.output_size,
                        to_stereo=self.to_stereo,
                        **{tr: True},
                    )
                )
        self.elements = elements
        # Shuffle so augmented variants are spread across batches.
        np.random.shuffle(self.elements)

    def __getitem__(self, idx):
        """Return batch idx as (X, y): float64 images in [0, 1] and bool one-hot."""
        batch_elements = self.elements[
            idx * self.batch_size : (idx + 1) * self.batch_size
        ]
        X, y = zip(*(element.get() for element in batch_elements))
        X, y = np.array(X), np.array(y)
        # One-hot encode, optionally remapping raw label values to channels.
        y_onehot = np.zeros(y.shape + (self.n_classes,))
        for i in np.unique(y):
            i = int(i)
            idx_for_this_class = np.where(y == i)
            # Channel that receives the 1s for this raw value.
            channel = self.class_mappings[i] if self.class_mappings else i
            y_onehot[
                idx_for_this_class
                + (np.full(len(idx_for_this_class[0]), channel, dtype=int),)
            ] = 1
        return X.astype(np.float64) / 255, y_onehot.astype(bool)

    def get_item_name(self, idx):
        """Return the image-file stem of the idx-th element (post-shuffle order)."""
        return self.elements[idx].img_path.stem

    def __len__(self):
        # np.int was removed in NumPy 1.24 (the old try/except only papered
        # over it); floor division is equivalent for these positive counts.
        return len(self.elements) // self.batch_size

    def on_epoch_end(self):
        # Reshuffle between epochs so batches differ each epoch.
        np.random.shuffle(self.elements)

    @classmethod
    def create_generators(
        cls,
        dataset_dir,
        n_classes,
        training_batch_size=32,
        validation_batch_size=8,
        output_size=(480, 640),
        to_stereo=False,
        transformations=tuple(),
        class_mappings=None,
    ):
        """
        Utility method to create both generators
        Args:
            dataset_dir: path of the dataset, must have training and val dirs
            training_batch_size: batch size of the training generator
            output_size: shape of the generated images
            transformations: for data augmentations
            to_stereo: whether the image and label must be converted to stereo
            class_mappings: dict containing a mapping for each class
        Returns: a tuple with the training and val generators
        """
        dataset_dir = Path(dataset_dir)
        training_generator = cls(
            dataset_dir / "training" / "images",
            dataset_dir / "training" / "labels",
            n_classes,
            batch_size=training_batch_size,
            output_size=output_size,
            to_stereo=to_stereo,
            **{tr: True for tr in transformations},
            class_mappings=class_mappings,
        )
        validation_generator = cls(
            dataset_dir / "val" / "images",
            dataset_dir / "val" / "labels",
            n_classes,
            batch_size=validation_batch_size,
            output_size=output_size,
            to_stereo=to_stereo,
            **{tr: True for tr in transformations},
            class_mappings=class_mappings,
        )
        return training_generator, validation_generator
# Half-band sizes (pixels) used by BaseDataset.gen_sample: the outermost band
# of the Canny edge map is cropped away and zero-padded back so border
# artifacts do not register as object boundaries.
y_k_size = 6
x_k_size = 6
class BaseDataset(Dataset):
    """Base segmentation dataset: normalization, padding, random cropping and
    multi-scale augmentation shared by concrete datasets."""

    def __init__(
        self,
        ignore_label=255,
        base_size=2048,
        crop_size=(512, 1024),
        scale_factor=16,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ):
        """
        Args:
            ignore_label: value padded into labels (to be ignored by losses).
            base_size: reference long-side size for multi-scale resizing.
            crop_size: (height, width) of random crops.
            scale_factor: upper bound (in tenths) of the random scale range.
            mean/std: per-channel normalization statistics, in [0, 1] range.
        """
        self.base_size = base_size
        self.crop_size = crop_size
        self.ignore_label = ignore_label
        self.mean = mean
        self.std = std
        self.scale_factor = scale_factor
        # Concrete subclasses populate this with their sample records.
        self.files = []

    def __len__(self):
        return len(self.files)

    def input_transform(self, image, city=True):
        """Scale to [0, 1] and normalize; BGR->RGB channel swap when city=True."""
        if city:
            image = image.astype(np.float32)[:, :, ::-1]
        else:
            image = image.astype(np.float32)
        image = image / 255.0
        image -= self.mean
        image /= self.std
        return image

    def label_transform(self, label):
        """Cast the label map to uint8 class ids."""
        return np.array(label).astype(np.uint8)

    def pad_image(self, image, h, w, size, padvalue):
        """Pad image on the bottom/right up to `size` with `padvalue`."""
        pad_image = image.copy()
        pad_h = max(size[0] - h, 0)
        pad_w = max(size[1] - w, 0)
        if pad_h > 0 or pad_w > 0:
            pad_image = cv2.copyMakeBorder(
                image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=padvalue
            )
        return pad_image

    def rand_crop(self, image, label, edge):
        """Take the same random crop_size window from image, label and edge."""
        h, w = image.shape[:-1]
        # Pad first so the crop window always fits; labels get ignore_label.
        image = self.pad_image(image, h, w, self.crop_size, (0.0, 0.0, 0.0))
        label = self.pad_image(label, h, w, self.crop_size, (self.ignore_label,))
        edge = self.pad_image(edge, h, w, self.crop_size, (0.0,))
        new_h, new_w = label.shape
        x = random.randint(0, new_w - self.crop_size[1])
        y = random.randint(0, new_h - self.crop_size[0])
        image = image[y : y + self.crop_size[0], x : x + self.crop_size[1]]
        label = label[y : y + self.crop_size[0], x : x + self.crop_size[1]]
        edge = edge[y : y + self.crop_size[0], x : x + self.crop_size[1]]
        return image, label, edge

    def multi_scale_aug(
        self, image, label=None, edge=None, rand_scale=1, rand_crop=True
    ):
        """Resize so the long side is base_size*rand_scale, then random-crop.

        Returns only the resized image when label is None.
        """
        # FIX: np.int was removed in NumPy 1.24 and raised AttributeError
        # here; built-in int() with the +0.5 round-half-up is equivalent.
        long_size = int(self.base_size * rand_scale + 0.5)
        h, w = image.shape[:2]
        if h > w:
            new_h = long_size
            new_w = int(w * long_size / h + 0.5)
        else:
            new_w = long_size
            new_h = int(h * long_size / w + 0.5)
        image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        if label is not None:
            label = cv2.resize(label, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
            if edge is not None:
                edge = cv2.resize(edge, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        else:
            return image
        if rand_crop:
            image, label, edge = self.rand_crop(image, label, edge)
        return image, label, edge

    def gen_sample(
        self,
        image,
        label,
        multi_scale=True,
        is_flip=True,
        edge_pad=True,
        edge_size=4,
        city=False,
    ):
        """Produce a training triple (image CHW, label map, binary edge map)."""
        edge = cv2.Canny(label, 0.1, 0.2)
        kernel = np.ones((edge_size, edge_size), np.uint8)
        if edge_pad:
            # Drop the outermost band and zero-pad it back so spurious edges
            # along the image border are suppressed.
            edge = edge[y_k_size:-y_k_size, x_k_size:-x_k_size]
            edge = np.pad(
                edge, ((y_k_size, y_k_size), (x_k_size, x_k_size)), mode="constant"
            )
        edge = (cv2.dilate(edge, kernel, iterations=1) > 50) * 1.0
        if multi_scale:
            # Random scale drawn from [0.5, 0.5 + scale_factor/10].
            rand_scale = 0.5 + random.randint(0, self.scale_factor) / 10.0
            image, label, edge = self.multi_scale_aug(
                image, label, edge, rand_scale=rand_scale
            )
        image = self.input_transform(image, city=city)
        label = self.label_transform(label)
        # HWC -> CHW for the network.
        image = image.transpose((2, 0, 1))
        if is_flip:
            # flip is +1 (no-op) or -1 (mirror) applied to the width axis.
            flip = np.random.choice(2) * 2 - 1
            image = image[:, :, ::flip]
            label = label[:, ::flip]
            edge = edge[:, ::flip]
        return image, label, edge

    def inference(self, config, model, image):
        """Run the model and upsample its prediction back to the input size."""
        size = image.size()
        pred = model(image)
        if config.MODEL.NUM_OUTPUTS > 1:
            pred = pred[config.TEST.OUTPUT_INDEX]
        pred = F.interpolate(
            input=pred,
            size=size[-2:],
            mode="bilinear",
            align_corners=config.MODEL.ALIGN_CORNERS,
        )
        # Model presumably emits log-probabilities; exp() recovers
        # probabilities — confirm against the model definition.
        return pred.exp()
class PIDNetDataset(BaseDataset):
    """Dataset serving PIDNet-style samples built from ImageHelper pairs.

    Each __getitem__ yields (image, label, edge, shape, name): a normalized
    CHW image, an integer label map, a binary boundary map (see
    BaseDataset.gen_sample), the image shape, and the file stem.
    """

    def __init__(
        self,
        images_path,
        labels_path,
        n_classes,
        output_size=(480, 640),
        to_stereo=False,
        flip=False,
        saturation=False,
        contrast=False,
        brightness=False,
        class_mappings=None,
        multi_scale=True,
        ignore_label=255,
        base_size=2048,
        crop_size=(512, 1024),
        scale_factor=16,
        # ImageNet statistics, kept for reference:
        # mean=[0.485, 0.456, 0.406],
        # std=[0.229, 0.224, 0.225],
        mean=[0.342, 0.374, 0.416],
        std=[0.241, 0.239, 0.253],
        bd_dilate_size=4,
    ):
        """
        Args:
            images_path: directory with the input images.
            labels_path: directory with one <image_stem>.png label per image.
            n_classes: number of classes (stored for callers).
            output_size: (height, width) passed to ImageHelper.
            to_stereo: duplicate image/label side by side.
            flip: random horizontal flip, applied inside gen_sample.
            saturation/contrast/brightness: color augmentations (one extra
                ImageHelper per image and enabled transformation).
            class_mappings: optional dict remapping raw label values.
            multi_scale/ignore_label/base_size/crop_size/scale_factor/mean/std:
                see BaseDataset.
            bd_dilate_size: dilation kernel size for the boundary map.

        Raises:
            FileNotFoundError: if any image lacks a matching label file.
        """
        super(PIDNetDataset, self).__init__(
            ignore_label, base_size, crop_size, scale_factor, mean, std
        )
        self.images_path = Path(images_path)
        self.labels_path = Path(labels_path)
        self.n_classes = n_classes
        self.output_size = output_size
        self.to_stereo = to_stereo
        self.class_mappings = class_mappings
        self.bd_dilate_size = bd_dilate_size
        self.multi_scale = multi_scale
        self.flip = flip
        # Check image and labels dir
        img_paths = sorted(self.images_path.iterdir())

        def has_label(img_filename):
            return (self.labels_path / f"{img_filename.stem}.png").exists()

        if not all(map(has_label, img_paths)):
            raise FileNotFoundError("Check every image has a label")
        # Obtain transformations. NOTE: unlike DataGenerator, 'flip' is not
        # an ImageHelper transformation here — gen_sample flips image, label
        # and edge together so the boundary map stays aligned.
        transformations = []
        if saturation:
            transformations.append("saturation")
        if contrast:
            transformations.append("contrast")
        if brightness:
            transformations.append("brightness")
        # Prepare augmentation: one plain helper per image plus one per
        # enabled transformation.
        elements = []
        for image_path in img_paths:
            label_path = self.labels_path / f"{image_path.stem}.png"
            elements.append(
                ImageHelper(
                    image_path,
                    label_path,
                    self.output_size,
                    to_stereo=self.to_stereo,
                )
            )
            for tr in transformations:
                elements.append(
                    ImageHelper(
                        image_path,
                        label_path,
                        self.output_size,
                        to_stereo=self.to_stereo,
                        **{tr: True},
                    )
                )
        self.elements = elements

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, idx):
        """Return (image, label, edge, shape, name) for sample idx."""
        element = self.elements[idx]
        name = element.img_path.stem
        X, y = element.get()
        # Remap raw label values to training class indices, if requested.
        if self.class_mappings:
            y = np.vectorize(lambda x: self.class_mappings[x])(y).astype(np.uint8)
        # FIX: a one-hot tensor was previously built here and then discarded
        # (nothing below used it). Worse, when class_mappings was set it
        # indexed the mapping with already-remapped values, risking a
        # KeyError. The dead computation has been removed.
        image, label = X, y
        image, label, edge = self.gen_sample(
            image, label, self.multi_scale, self.flip, edge_size=self.bd_dilate_size
        )
        return image.copy(), label.copy(), edge.copy(), np.array(image.shape), name

    @classmethod
    def create_train_and_test_datasets(
        cls,
        dataset_dir,
        n_classes,
        output_size=(480, 640),
        to_stereo=False,
        transformations=tuple(),
        class_mappings=None,
    ):
        """Build (training, validation) datasets from a directory containing
        'training' and 'val' subdirectories.

        NOTE(review): transformations are applied to the training split only;
        the validation split is deliberately built without augmentation —
        confirm this is intended.
        """
        dataset_dir = Path(dataset_dir)
        training_generator = cls(
            dataset_dir / "training" / "images",
            dataset_dir / "training" / "labels",
            n_classes,
            output_size=output_size,
            to_stereo=to_stereo,
            **{tr: True for tr in transformations},
            class_mappings=class_mappings,
        )
        validation_generator = cls(
            dataset_dir / "val" / "images",
            dataset_dir / "val" / "labels",
            n_classes,
            output_size=output_size,
            to_stereo=to_stereo,
            class_mappings=class_mappings,
        )
        return training_generator, validation_generator
class MergedDataset(Dataset):
    """Concatenates several torch Datasets behind a single index space.

    Global index i falls into the dataset whose cumulative length span
    contains i; the local index is i minus the preceding datasets' total.
    """

    def __init__(self, *datasets):
        """
        Args:
            *datasets: torch Dataset instances to concatenate, in order.
        """
        self.datasets = datasets
        for d in self.datasets:
            assert isinstance(d, Dataset)
        self.lens = [len(d) for d in self.datasets]
        # Cumulative lengths; acc_lens[i] is the first global index that is
        # NOT served by dataset i. (accumulate replaces the old quadratic
        # repeated-sum construction with a single linear pass.)
        self.acc_lens = list(itertools.accumulate(self.lens))

    def __len__(self):
        return sum(self.lens)

    def __getitem__(self, idx):
        """Return the sample for global index idx.

        Raises:
            ValueError: if idx >= len(self).
        """
        # Binary search for the owning dataset (the old code scanned
        # acc_lens linearly); bisect_right gives the first span whose
        # exclusive upper bound exceeds idx.
        i = bisect.bisect_right(self.acc_lens, idx)
        if i < len(self.acc_lens):
            offset = self.acc_lens[i - 1] if i != 0 else 0
            return self.datasets[i][idx - offset]
        raise ValueError(
            f"idx out of range, was {idx}, should be less than {self.__len__()}"
        )
if __name__ == "__main__":
    """
    dataset_dir = Path('/home/user/nas/Datasets/egocentric_segmentation/joint-ep-of-thu-ego-for-5-office-objects/')
    helper = ImageHelper(
        dataset_dir / 'training' / 'images' / 'L515_020_003_rgb_0246.jpg',
        dataset_dir / 'training' / 'labels' / 'L515_020_003_rgb_0246.png',
        (480, 640),
        to_stereo=True
    )
    image, label = helper.get()
    """
    # Smoke test: build a generator over the pruned training split and fetch
    # the first batch.
    base_dir = Path(
        "C:/Users/xruser/RealTimeSemanticSegmentation/joint-ep-of-thu-ego-stereo-1280x480/joint-ep-of-thu-ego-stereo-1280x480/"
    )
    gen = DataGenerator(
        base_dir / "pruned_training" / "images",
        base_dir / "pruned_training" / "labels",
        7,
        batch_size=4,
        to_stereo=True,
    )
    images, labels = gen[0]
    print("hola")