import random
from math import inf

import cv2
import numpy as np
import torch

from floortrans.loaders import svg_utils

# Number of junction / corner heatmap channels (legend below).
NUM_HEATMAP_CHANNELS = 21


class Compose(object):
    """Apply a sequence of augmentations, feeding each one's output into the next."""

    def __init__(self, augmentations):
        self.augmentations = augmentations

    def __call__(self, sample):
        for augmentation in self.augmentations:
            sample = augmentation(sample)
        return sample


# Heatmap channel legend:
# 0. I
# 1. I top to right
# 2. I vertical flip
# 3. I top to left
# 4. L horizontal flip
# 5. L
# 6. L vertical flip
# 7. L horizontal and vertical flip
# 8. T
# 9. T top to right
# 10. T top to down
# 11. T top to left
# 12. X or +
# 13. Opening left corner
# 14. Opening right corner
# 15. Opening up corner
# 16. Opening down corner
# 17. Icon upper left
# 18. Icon upper right
# 19. Icon lower left
# 20. Icon lower right

# Channel each junction type moves to after ONE clockwise quarter turn of the
# image (the turn implemented by RandomRotations and RotateNTurns with n == 1).
# Composing the map with itself gives the half-turn mapping; its inverse gives
# the counter-clockwise mapping.
_QUARTER_TURN_CHANNEL_MAP = {
    0: 1, 1: 2, 2: 3, 3: 0,
    4: 5, 5: 6, 6: 7, 7: 4,
    8: 9, 9: 10, 10: 11, 11: 8,
    12: 12,
    13: 15, 14: 16, 15: 14, 16: 13,
    17: 18, 18: 20, 19: 17, 20: 19,
}


class RandomRotations(object):
    """Rotate a sample clockwise by a random number of quarter turns.

    The image and the label are rotated together. Heatmap points (x = column,
    y = row) are moved with the pixels and re-keyed to the channel that matches
    their new orientation.

    Args:
        format: "furu" (points stored under ``"heatmap_points"``) or "cubi"
            (points under ``"heatmaps"``, plus a ``"scale"`` entry that is
            passed through).
        max_turns: largest number of quarter turns that can be drawn
            (inclusive). The default of 2 reproduces the historical
            ``torch.randint(0, 3)`` draw, which never yields a 270 degree turn.
            NOTE(review): 3 may have been intended; pass ``max_turns=3`` to
            include it.

    Raises:
        ValueError: if ``format`` is not recognised.
    """

    def __init__(self, format="furu", max_turns=2):
        if format == "furu":
            self.augment = self.furu
        elif format == "cubi":
            self.augment = self.cubi
        else:
            raise ValueError("Unknown format {!r}; expected 'furu' or 'cubi'".format(format))
        self.max_turns = max_turns

    def __call__(self, sample):
        return self.augment(sample)

    def _rotate(self, fplan, segmentation, heatmap_points):
        """Rotate image, label and points by the same random number of turns."""
        num_of_rotations = int(torch.randint(0, self.max_turns + 1, (1,)))
        for _ in range(num_of_rotations):
            # Swap the two spatial axes, then mirror the columns: one clockwise turn.
            fplan = fplan.transpose(2, 1).flip(2)
            segmentation = segmentation.transpose(2, 1).flip(2)
            # A clockwise turn sends (x, y) to (old_height - 1 - y, x). After the
            # turn the old height is the new width (shape[2]); shape[1] is only
            # correct for square inputs.
            new_width = fplan.shape[2]
            heatmap_points = {
                _QUARTER_TURN_CHANNEL_MAP[junction_type]: [
                    [new_width - 1 - point[1], point[0]] for point in points
                ]
                for junction_type, points in heatmap_points.items()
            }
        return fplan, segmentation, heatmap_points

    def cubi(self, sample):
        """Rotate a "cubi" sample (keys: image, label, heatmaps, scale)."""
        scale = sample["scale"]
        fplan, segmentation, heatmap_points = self._rotate(
            sample["image"], sample["label"], sample["heatmaps"]
        )
        return {"image": fplan, "label": segmentation, "scale": scale, "heatmaps": heatmap_points}

    def furu(self, sample):
        """Rotate a "furu" sample (keys: image, label, heatmap_points)."""
        fplan, segmentation, heatmap_points = self._rotate(
            sample["image"], sample["label"], sample["heatmap_points"]
        )
        return {"image": fplan, "label": segmentation, "heatmap_points": heatmap_points}


def clip_heatmaps(heatmaps, minx, maxx, miny, maxy):
    """Keep points inside ``[minx, maxx) x [miny, maxy)`` and shift them to the crop.

    Args:
        heatmaps: dict mapping a channel key to a list of ``(x, y)`` points.
        minx, maxx, miny, maxy: half-open crop bounds in the input coordinates.

    Returns:
        A new dict with the same keys. Each value is a list of ``(x - minx, y - miny)``
        tuples for the points that fall inside the bounds.
    """
    return {
        key: [
            (point[0] - minx, point[1] - miny)
            for point in points
            if minx <= point[0] < maxx and miny <= point[1] < maxy
        ]
        for key, points in heatmaps.items()
    }


class DictToTensor(object):
    """Rasterise heatmap points, blur them, and stack them on top of the label.

    The output label is ``cat(heatmaps (21 channels), original label)`` along
    dim 0.

    Args:
        data_format: "cubi" (points under ``"heatmaps"``, kernel size scaled by
            ``sample["scale"]``) or "furukawa" (points under
            ``"heatmap_points"``, fixed kernel, constant-zero border).

    Raises:
        ValueError: if ``data_format`` is not recognised.
    """

    def __init__(self, data_format="cubi"):
        if data_format == "cubi":
            self.augment = self.cubi
        elif data_format == "furukawa":
            self.augment = self.furukawa
        else:
            raise ValueError(
                "Unknown data_format {!r}; expected 'cubi' or 'furukawa'".format(data_format)
            )

    def __call__(self, sample):
        return self.augment(sample)

    @staticmethod
    def _points_to_heatmaps(heatmap_points, height, width):
        """Return a (21, height, width) float64 array with a 1 at every point.

        Points are ``(x, y)``, meaning column and row. Coordinates are truncated
        with ``int()`` and clamped into the grid. This covers the off-by-one
        points that land on the far edge without letting negative indices wrap
        around to the opposite side.
        """
        heatmap_tensor = np.zeros((NUM_HEATMAP_CHANNELS, height, width))
        for channel, coords in heatmap_points.items():
            for x, y in coords:
                col = min(max(int(x), 0), width - 1)
                row = min(max(int(y), 0), height - 1)
                heatmap_tensor[int(channel), row, col] = 1
        return heatmap_tensor

    @staticmethod
    def _blur(heatmap_tensor, kernel, **filter_kwargs):
        """Convolve every channel with ``kernel`` in place and return it as a FloatTensor."""
        for i, channel in enumerate(heatmap_tensor):
            heatmap_tensor[i] = cv2.filter2D(channel, -1, kernel, **filter_kwargs)
        return torch.FloatTensor(heatmap_tensor)

    def cubi(self, sample):
        image, label = sample["image"], sample["label"]
        _, height, width = label.shape
        heatmap_points = sample["heatmaps"]
        scale = sample["scale"]
        heatmap_tensor = self._points_to_heatmaps(heatmap_points, height, width)
        # The Gaussian kernel size grows with the sample's scale factor.
        kernel = svg_utils.get_gaussian2D(int(30 * scale))
        heatmaps = self._blur(heatmap_tensor, kernel)
        label = torch.cat((heatmaps, label), 0)
        return {"image": image, "label": label}

    def furukawa(self, sample):
        image, label = sample["image"], sample["label"]
        _, height, width = label.shape
        heatmap_tensor = self._points_to_heatmaps(sample["heatmap_points"], height, width)
        kernel = svg_utils.get_gaussian2D(13)
        heatmaps = self._blur(heatmap_tensor, kernel, borderType=cv2.BORDER_CONSTANT, delta=0)
        label = torch.cat((heatmaps, label), 0)
        return {"image": image, "label": label}


class RotateNTurns(object):
    """Rotate batched tensors, or permute their junction channels, by quarter turns.

    ``n`` is 1 (one clockwise turn), -1 (one counter-clockwise turn) or 2 (half
    turn). Any other value leaves the data unrotated.
    """

    def rot_tensor(self, t, n):
        """Rotate the spatial axes (dims 2 and 3) of ``t``.

        Assumes ``t`` is laid out as (batch, channel, H, W); TODO confirm with callers.
        """
        if n == 1:
            # One turn clockwise.
            t = t.flip(2).transpose(3, 2)
        elif n == -1:
            # One turn counter-clockwise.
            t = t.transpose(3, 2).flip(2)
        elif n == 2:
            # Two turns.
            t = t.flip(2).flip(3)
        return t

    def rot_points(self, t, n):
        """Re-order the 21 junction channels (dim 1) to match a rotation by ``n`` turns.

        Channels 21 and above are copied unchanged. This always returns a
        detached clone, even when ``n`` requests no rotation.
        """
        t_sorted = t.clone().detach()
        if n == 1:
            mapping = _QUARTER_TURN_CHANNEL_MAP
        elif n == -1:
            mapping = {new: old for old, new in _QUARTER_TURN_CHANNEL_MAP.items()}
        elif n == 2:
            mapping = {
                old: _QUARTER_TURN_CHANNEL_MAP[new]
                for old, new in _QUARTER_TURN_CHANNEL_MAP.items()
            }
        else:
            return t_sorted
        for old_channel, new_channel in mapping.items():
            t_sorted[:, new_channel] = t[:, old_channel]
        return t_sorted

    def __call__(self, sample, data_type, n):
        if data_type == "tensor":
            return self.rot_tensor(sample, n)
        elif data_type == "points":
            return self.rot_points(sample, n)


class RandomCropToSizeTorch(object):
    """Pad a sample by half the crop size on each side, then take a random crop.

    Args:
        input_slice: channel counts of (heatmaps, rooms, icons) in the stacked
            label. Used only by the "tensor" format.
        size: ``(width, height)`` of the crop.
        fill: fill values for the padded (rooms, icons) label channels.
        data_format: "tensor", "dict" or "dict furu".
        dtype: dtype of the padded image and label in the "tensor" format.
        max_size: stored but not used by this class.

    Raises:
        ValueError: if ``data_format`` is not recognised.
    """

    def __init__(
        self,
        input_slice=(21, 1, 1),
        size=(256, 256),
        fill=(0, 0),
        data_format="tensor",
        dtype=torch.float32,
        max_size=None,
    ):
        self.size = size
        self.width = size[0]
        self.height = size[1]
        self.dtype = dtype
        self.fill = fill
        self.max_size = max_size
        self.input_slice = input_slice
        if data_format == "dict":
            self.augment = self.augment_dict
        elif data_format == "tensor":
            self.augment = self.augment_tesor
        elif data_format == "dict furu":
            self.augment = self.augment_dict_furu
        else:
            raise ValueError(
                "Unknown data_format {!r}; expected 'tensor', 'dict' or 'dict furu'".format(
                    data_format
                )
            )

    def __call__(self, sample):
        return self.augment(sample)

    def _random_crop_origin(self, canvas_h, canvas_w):
        """Draw the top-left corner of a (height, width) crop inside the canvas."""
        top = random.randint(0, canvas_h - self.height)
        left = random.randint(0, canvas_w - self.width)
        return top, left

    def augment_tesor(self, sample):
        """Crop a sample whose label stacks heatmaps, rooms and icons channels."""
        image, label = sample["image"], sample["label"]
        img_h, img_w = image.shape[1], image.shape[2]
        pad_w = int(self.width / 2)
        pad_h = int(self.height / 2)
        new_w = self.width + max(img_w, self.width)
        new_h = self.height + max(img_h, self.height)
        n_heatmaps, n_rooms, n_icons = self.input_slice[0], self.input_slice[1], self.input_slice[2]
        inner = (slice(None), slice(pad_h, pad_h + img_h), slice(pad_w, pad_w + img_w))

        new_image = torch.zeros([image.shape[0], new_h, new_w], dtype=self.dtype)
        new_image[inner] = image
        new_heatmaps = torch.zeros([n_heatmaps, new_h, new_w], dtype=self.dtype)
        new_heatmaps[inner] = label[:n_heatmaps]
        # An explicit dtype keeps the int fill value from producing an int64
        # tensor that truncates the label on assignment.
        new_rooms = torch.full((n_rooms, new_h, new_w), self.fill[0], dtype=self.dtype)
        new_rooms[inner] = label[n_heatmaps : n_heatmaps + n_rooms]
        new_icons = torch.full((n_icons, new_h, new_w), self.fill[1], dtype=self.dtype)
        icons_start = n_heatmaps + n_rooms
        new_icons[inner] = label[icons_start : icons_start + n_icons]
        label = torch.cat((new_heatmaps, new_rooms, new_icons), 0)

        top, left = self._random_crop_origin(new_h, new_w)
        bottom, right = top + self.height, left + self.width
        return {
            "image": new_image[:, top:bottom, left:right],
            "label": label[:, top:bottom, left:right],
        }

    def _pad_and_crop_dict(self, image, label, heatmap_points):
        """Shared implementation of the dict formats.

        The image is padded with 255; the label must have (rooms, icons)
        channels. Heatmap points are shifted by the padding and clipped to the
        crop.
        """
        img_h, img_w = image.shape[1], image.shape[2]
        pad_w = int(self.width / 2)
        pad_h = int(self.height / 2)
        new_w = self.width + img_w
        new_h = self.height + img_h
        inner = (slice(None), slice(pad_h, pad_h + img_h), slice(pad_w, pad_w + img_w))

        # Keep the input dtypes; torch.full would otherwise infer int64 from the
        # int fill values and truncate float data.
        new_image = torch.full([image.shape[0], new_h, new_w], 255, dtype=image.dtype)
        new_image[inner] = image
        new_rooms = torch.full((1, new_h, new_w), self.fill[0], dtype=label.dtype)
        new_rooms[inner] = label[0]
        new_icons = torch.full((1, new_h, new_w), self.fill[1], dtype=label.dtype)
        new_icons[inner] = label[1]
        label = torch.cat((new_rooms, new_icons), 0)

        top, left = self._random_crop_origin(new_h, new_w)
        bottom, right = top + self.height, left + self.width

        shifted_points = {
            junction_type: [[point[0] + pad_w, point[1] + pad_h] for point in points]
            for junction_type, points in heatmap_points.items()
        }
        heatmap_points = clip_heatmaps(shifted_points, left, right, top, bottom)
        return new_image[:, top:bottom, left:right], label[:, top:bottom, left:right], heatmap_points

    def augment_dict(self, sample):
        image, label, heatmap_points = self._pad_and_crop_dict(
            sample["image"], sample["label"], sample["heatmaps"]
        )
        return {"image": image, "label": label, "heatmaps": heatmap_points, "scale": sample["scale"]}

    def augment_dict_furu(self, sample):
        image, label, heatmap_points = self._pad_and_crop_dict(
            sample["image"], sample["label"], sample["heatmap_points"]
        )
        return {"image": image, "label": label, "heatmap_points": heatmap_points}


class ColorJitterTorch(object):
    """Randomly jitter brightness, contrast and saturation of ``sample["image"]``.

    Each factor is ``1 + U(-var, var)``. Results are clamped to [0, 255], and
    the input dict is updated in place and returned. The image is assumed to
    have 3 channels in RGB order (see ``grayscale``); TODO confirm.
    """

    def __init__(self, b_var=0.4, c_var=0.4, s_var=0.4, dtype=torch.float32, version="dict"):
        self.b_var = b_var
        self.c_var = c_var
        self.s_var = s_var
        self.dtype = dtype
        self.version = version

    def __call__(self, sample):
        res = sample
        image = sample["image"]
        image = self.brightness(image, self.b_var)
        image = self.contrast(image, self.c_var)
        image = self.saturation(image, self.s_var)
        res["image"] = image
        return res

    def blend(self, img_1, img_2, var):
        """Return ``clamp(alpha * img_1 + (1 - alpha) * img_2, 0, 255)`` with alpha in [1-var, 1+var]."""
        m = torch.tensor([0], dtype=self.dtype).uniform_(-var, var)
        alpha = 1 + m
        res = img_1 * alpha + (1 - alpha) * img_2
        return torch.clamp(res, min=0, max=255)

    def grayscale(self, img):
        """Luma-weighted gray image replicated to 3 channels."""
        gray = img[0] * 0.299 + img[1] * 0.587 + img[2] * 0.114
        gray = torch.clamp(gray, min=0, max=255)
        return torch.stack((gray, gray, gray), dim=0)

    def saturation(self, img, var):
        return self.blend(img, self.grayscale(img), var)

    def brightness(self, img, var):
        return self.blend(img, torch.zeros(img.shape), var)

    def contrast(self, img, var):
        gray = self.grayscale(img)
        # Blend towards a constant image at the mean gray level.
        mean_image = torch.full_like(gray, float(gray.mean()))
        return self.blend(img, mean_image, var)


class ResizePaddedTorch(object):
    """Resize a sample to fit ``size`` while keeping its aspect ratio, then pad it.

    The image is always resized bilinearly. The label is handled according to
    ``data_format``: "tensor" (stacked heatmaps, rooms and icons), "dict" or
    "dict furu" (rooms and icons channels plus heatmap points).

    Raises:
        ValueError: if ``data_format`` is not recognised.
    """

    def __init__(self, fill, size=(256, 256), both=True, dtype=torch.float32, data_format="tensor"):
        self.size = size
        self.width = size[0]
        self.height = size[1]
        self.both = both
        self.dtype = dtype
        self.fill = fill
        self.fill_cval = 255
        if data_format == "tensor":
            self.augment = self.augment_tensor
        elif data_format == "dict furu":
            self.augment = self.augment_dict_furu
        elif data_format == "dict":
            self.augment = self.augment_dict
            self.fill_cval = 1
        else:
            raise ValueError(
                "Unknown data_format {!r}; expected 'tensor', 'dict' or 'dict furu'".format(
                    data_format
                )
            )

    def __call__(self, sample):
        # Image: bilinear interpolation, padded with self.fill_cval.
        # Note: this writes into the input dict.
        image, _, _, _ = self.resize_padded(
            sample["image"], self.size, fill_cval=self.fill_cval, image=True, mode="bilinear", aling_corners=False
        )
        sample["image"] = image
        return self.augment(sample)

    def augment_tensor(self, sample):
        image, label = sample["image"], sample["label"]
        if self.both:
            # Heatmaps are bilinearly resized; rooms and icons use nearest neighbour
            # so class ids stay intact.
            heatmaps, _, _, _ = self.resize_padded(
                label[:NUM_HEATMAP_CHANNELS], self.size, mode="bilinear", aling_corners=False
            )
            rooms_padded, _, _, _ = self.resize_padded(
                label[[NUM_HEATMAP_CHANNELS]], self.size, mode="nearest", fill_cval=self.fill[0]
            )
            icons_padded, _, _, _ = self.resize_padded(
                label[[NUM_HEATMAP_CHANNELS + 1]], self.size, mode="nearest", fill_cval=self.fill[1]
            )
            label = torch.cat((heatmaps, rooms_padded, icons_padded), dim=0)
        return {"image": image, "label": label}

    def _resize_labels_and_points(self, label, heatmap_points, drop_outside):
        """Resize the (rooms, icons) label channels and map points into the output.

        Points are ``(x, y)``, meaning column and row. Each point is scaled by the
        resize ratio and offset by the padding. When ``drop_outside`` is true,
        points that fall outside the output grid are discarded.
        """
        rooms_padded, _, _, _ = self.resize_padded(
            label[[0]], self.size, mode="nearest", fill_cval=self.fill[0]
        )
        icons_padded, ratio, y_pad, x_pad = self.resize_padded(
            label[[1]], self.size, mode="nearest", fill_cval=self.fill[1]
        )
        label = torch.cat((rooms_padded, icons_padded), dim=0)

        # resize_padded lays rows along size[0] and columns along size[1].
        out_rows, out_cols = self.size[0], self.size[1]
        new_heatmap_points = {}
        for junction_type, points in heatmap_points.items():
            moved = []
            for point in points:
                new_x = point[0] * ratio + x_pad
                new_y = point[1] * ratio + y_pad
                if drop_outside and not (0 <= new_y < out_rows and 0 <= new_x < out_cols):
                    continue
                moved.append([new_x, new_y])
            new_heatmap_points[junction_type] = moved
        return label, new_heatmap_points

    def augment_dict_furu(self, sample):
        label, heatmap_points = self._resize_labels_and_points(
            sample["label"], sample["heatmap_points"], drop_outside=False
        )
        return {"image": sample["image"], "label": label, "heatmap_points": heatmap_points}

    def augment_dict(self, sample):
        label, heatmap_points = self._resize_labels_and_points(
            sample["label"], sample["heatmaps"], drop_outside=True
        )
        return {"image": sample["image"], "label": label, "heatmaps": heatmap_points, "scale": sample["scale"]}

    def resize_padded(self, img, new_shape, image=False, fill_cval=0, mode="nearest", aling_corners=None):
        """Resize a (C, H, W) tensor to fit ``new_shape`` and centre it on a filled canvas.

        The scale ratio is the minimum of the per-dimension ratios, including the
        channel dimension (whose ratio is 1). As a result, inputs are only ever
        shrunk, never enlarged.

        Args:
            img: (C, H, W) tensor.
            new_shape: ``(rows, cols)`` of the output canvas.
            image: unused; kept for call compatibility.
            fill_cval: value used for the padding.
            mode: interpolation mode passed to ``torch.nn.functional.interpolate``.
            aling_corners: passed through as ``align_corners``.

        Returns:
            ``(padded, ratio, row_pad, col_pad)``. ``ratio`` is a 0-d tensor; the
            pads are the offsets of the resized image inside the canvas.
        """
        target = torch.tensor([img.shape[0], new_shape[0], new_shape[1]], dtype=self.dtype)
        old_shape = torch.tensor(img.shape, dtype=self.dtype)
        ratio = (target / old_shape).min()

        scaled = (ratio * torch.tensor(img.shape[1:], dtype=self.dtype)).ceil()
        # interpolate needs Python ints. Clamping guards against float32 rounding
        # pushing ceil() one pixel past the canvas.
        interm_size = [
            min(int(scaled[0]), int(new_shape[0])),
            min(int(scaled[1]), int(new_shape[1])),
        ]
        interm_img = torch.nn.functional.interpolate(
            img.unsqueeze(0), size=interm_size, mode=mode, align_corners=aling_corners
        ).squeeze(0)

        # Match the resized data's dtype so an int fill value does not create an
        # int64 canvas that truncates float data.
        new_img = torch.full(
            (interm_img.shape[0], new_shape[0], new_shape[1]), fill_cval, dtype=interm_img.dtype
        )
        row_pad = (int(new_shape[0]) - interm_img.shape[1]) // 2
        col_pad = (int(new_shape[1]) - interm_img.shape[2]) // 2
        new_img[
            :, row_pad : row_pad + interm_img.shape[1], col_pad : col_pad + interm_img.shape[2]
        ] = interm_img
        return new_img, ratio, row_pad, col_pad