import argparse import os import random from collections import defaultdict from pathlib import Path import cv2 import numpy as np import plotly.graph_objects as go import torch from torch.utils.data import DataLoader from datasets import build_dataset from datasets.data_utils import sort_polygons from util.plot_utils import ( CC5K_LABEL, S3D_LABEL, auto_crop_whitespace, plot_room_map, plot_semantic_rich_floorplan_opencv, plot_semantic_rich_floorplan_tight, ) def unnormalize_image(x): mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) return x * std + mean def plot_gt_floor( args, data_loader, device, output_dir, plot_gt=True, semantic_rich=False, dataset_name="cubicasa", crop_white_space=False, ): if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) semantics_label_mapping = None if args.dataset_name == "stru3d": door_window_index = [16, 17] semantics_label_mapping = S3D_LABEL elif args.dataset_name == "cubicasa": door_window_index = [10, 9] semantics_label_mapping = CC5K_LABEL elif args.dataset_name == "waffle": door_window_index = [1, 2] else: door_window_index = [] for batched_inputs, _ in data_loader: samples = [x["image"].to(device) for x in batched_inputs] scene_ids = [x["image_id"] for x in batched_inputs] gt_instances = [x["instances"].to(device) for x in batched_inputs] # draw GT map if plot_gt: for i, gt_inst in enumerate(gt_instances): image = np.transpose((samples[i] * 255).cpu().numpy(), [1, 2, 0]) if not semantic_rich: # plot regular room floorplan gt_polys = [] density_map = np.transpose((samples[i] * 255).cpu().numpy(), [1, 2, 0]) density_map = np.repeat(density_map, 3, axis=2) for j, poly in enumerate(gt_inst.gt_masks.polygons): corners = poly[0].reshape(-1, 2) if len(corners) < 3: continue gt_polys.append(corners) gt_room_polys = [np.array(r) for r in gt_polys] gt_polygons_labels = gt_inst.gt_classes.cpu().numpy() gt_sem_rich = [] for poly, poly_type in zip(gt_room_polys, gt_polygons_labels): gt_sem_rich.append([poly, poly_type]) if args.plot_engine == "opencv": gt_sem_rich_path: str = os.path.join( output_dir, "{}_floor.png".format(str(scene_ids[i]).zfill(5)) ) gt_floorplan_map = plot_semantic_rich_floorplan_opencv( gt_sem_rich, None, door_window_index=door_window_index, img_w=args.image_size * args.image_scale, img_h=args.image_size * args.image_scale, semantics_label_mapping=semantics_label_mapping, plot_text=False, scale=args.image_scale, is_sem=True, one_color=args.one_color, is_bw=args.is_bw, ) if crop_white_space: image = cv2.resize( image, (args.image_scale * args.image_size, args.image_scale * args.image_size), interpolation=cv2.INTER_NEAREST, ) image, cropped_box = auto_crop_whitespace(image) _x, _y, _w, _h = [ele for ele in cropped_box] gt_floorplan_map = gt_floorplan_map[_y : _y + _h, _x : _x + _w].copy() cv2.imwrite(gt_sem_rich_path, gt_floorplan_map, [cv2.IMWRITE_PNG_COMPRESSION, 0]) else: gt_sem_rich_path = os.path.join(output_dir, "{}.png".format(str(scene_ids[i]).zfill(5))) plot_semantic_rich_floorplan_tight( gt_sem_rich, gt_sem_rich_path, None, None, plot_text=False, is_bw=args.is_bw, door_window_index=door_window_index, img_w=args.image_size * args.image_scale, img_h=args.image_size * args.image_scale, ) else: # plot semantically-rich floorplan gt_polygons_labels = gt_inst.gt_classes.cpu().numpy() gt_polygons = gt_inst.gt_masks.polygons gt_polygons, sorted_indices = sort_polygons(gt_polygons) gt_polygons_labels = [gt_polygons_labels[_idx] for _idx in sorted_indices] gt_sem_rich = [] for j, (poly, poly_label) in enumerate(zip(gt_polygons, gt_polygons_labels)): # if gt_inst.gt_classes.cpu().numpy()[j] not in [1, 9, 11]: # continue corners = poly[0].reshape(-1, 2).astype(np.int32) # corners_flip_y = corners.copy() # corners_flip_y[:,1] = 255 - corners_flip_y[:,1] # corners = corners_flip_y gt_sem_rich.append([corners, poly_label]) if args.plot_engine == "opencv": gt_sem_rich_path = os.path.join(output_dir, "{}_floor.png".format(str(scene_ids[i]).zfill(5))) gt_floorplan_map = plot_semantic_rich_floorplan_opencv( gt_sem_rich, None, door_window_index=door_window_index, semantics_label_mapping=semantics_label_mapping, scale=args.image_scale, img_w=args.image_size * args.image_scale, img_h=args.image_size * args.image_scale, is_bw=args.is_bw, plot_text=False, one_color=args.one_color, ) if crop_white_space: image, cropped_box = auto_crop_whitespace(image) _x, _y, _w, _h = [ele * args.image_scale for ele in cropped_box] gt_floorplan_map = gt_floorplan_map[_y : _y + _h, _x : _x + _w].copy() cv2.imwrite(gt_sem_rich_path, gt_floorplan_map, [cv2.IMWRITE_PNG_COMPRESSION, 0]) else: gt_sem_rich_path = os.path.join(output_dir, "{}.png".format(str(scene_ids[i]).zfill(5))) plot_semantic_rich_floorplan_tight( gt_sem_rich, gt_sem_rich_path, None, None, plot_text=False, is_bw=args.is_bw, door_window_index=door_window_index, img_w=args.image_size * args.image_scale, img_h=args.image_size * args.image_scale, ) def plot_polys(data_loader, device, output_dir): if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) for batched_inputs, _ in data_loader: samples = [x["image"].to(device) for x in batched_inputs] scene_ids = [x["image_id"] for x in batched_inputs] gt_instances = [x["instances"].to(device) for x in batched_inputs] for i in range(len(samples)): density_map = np.transpose((samples[i]).cpu().numpy(), [1, 2, 0]) if density_map.shape[2] == 3: density_map = density_map * 255 else: density_map = np.repeat(density_map, 3, axis=2) * 255 pred_room_map = np.zeros(density_map.shape).astype(np.uint8) room_polys = gt_instances[i].gt_masks.polygons room_ids = gt_instances[i].gt_classes.detach().cpu().numpy() for poly, poly_id in zip(room_polys, room_ids): poly = poly[0].reshape(-1, 2).astype(np.int32) pred_room_map = plot_room_map(poly, pred_room_map, poly_id) # Blend the overlay with the density map using alpha blending alpha = 0.6 # Adjust for desired transparency pred_room_map = cv2.addWeighted( density_map.astype(np.uint8), alpha, pred_room_map.astype(np.uint8), 1 - alpha, 0 ) # # plot predicted polygon overlaid on the density map # pred_room_map = np.clip(pred_room_map + density_map, 0, 255) cv2.imwrite(os.path.join(output_dir, "{}_pred_room_map.png".format(scene_ids[i])), pred_room_map) def plot_gt_image(data_loader, device, output_dir, crop_white_space=False): if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) for batched_inputs, _ in data_loader: samples = [x["image"].to(device) for x in batched_inputs] scene_ids = [x["image_id"] for x in batched_inputs] for i in range(len(samples)): density_map = np.transpose((samples[i]).cpu().numpy(), [1, 2, 0]) if density_map.shape[2] == 3: density_map = density_map * 255 else: density_map = np.repeat(density_map, 3, axis=2) * 255 if crop_white_space: density_map = cv2.resize( density_map, (args.image_scale * args.image_size, args.image_scale * args.image_size), interpolation=cv2.INTER_NEAREST, ) density_map, _ = auto_crop_whitespace(image=density_map, color_invert=True) cv2.imwrite(os.path.join(output_dir, "{}_gt_image.png".format(scene_ids[i])), density_map) def plot_histogram(count_dict, title, output_path, bin_size=10): # Group keys into bins based on the bin_size binned_count_dict = {} for key, value in count_dict.items(): bin_key = (key // bin_size) * bin_size # Determine the bin for the key binned_count_dict[bin_key] = binned_count_dict.get(bin_key, 0) + value # Sort the bins binned_keys = sorted(binned_count_dict.keys()) binned_values = [binned_count_dict[key] for key in binned_keys] # Determine the maximum value for the y-axis max_y = max(binned_values) # Adjust y-axis ticks dynamically for large ranges tick_interval = max(1, max_y // 10) # Divide the range into 10 intervals tickvals_y = list(range(0, max_y + tick_interval, tick_interval)) # Determine tick values for x-axis dynamically tickvals_x = binned_keys # Use the binned keys as tick values fig = go.Figure( data=[ go.Bar( x=binned_keys, y=binned_values, text=binned_values, textposition="outside", marker=dict(color="blue"), width=0.5, ) ] ) fig.update_layout( title={ "text": f"Histogram of {title}", "font": {"size": 30}, # Increase title font size "x": 0.5, # Center the title }, xaxis_title={"text": f"Number of {title}", "font": {"size": 24}}, # Increase x-axis label font size yaxis_title={"text": "Frequency", "font": {"size": 24}}, # Increase y-axis label font size xaxis=dict( tickmode="array", # Use custom tick values tickvals=tickvals_x, ticktext=[f"{x}-{x + bin_size - 1}" for x in binned_keys], tickfont=dict(size=20), # Increase x-axis tick font size ), yaxis=dict( tickvals=tickvals_y, # Set custom tick values ticktext=[str(val) for val in tickvals_y], # Set custom tick labels tickfont=dict(size=20), # Increase y-axis tick font size ), template="plotly_white", # bargap=0.5, # Add gap between bars (0.5 = 50% of bar width) # Increase figure width for a long x-axis width=max(600, 30 * len(binned_keys)), # Dynamic width based on number of bars ) # Save the figure as an image fig.write_image(output_path, scale=3) print(f"Figure saved to {output_path}") # fig.show() def loop_data(data_loader, eval_set, device, output_dir): max_num_points = -1 max_num_polys = -1 count_pts_dict = defaultdict(lambda: 0) count_room_dict = defaultdict(lambda: 0) count_length_dict = defaultdict(lambda: 0) for batched_inputs, batched_extras in data_loader: samples = [x["image"].to(device) for x in batched_inputs] gt_instances = [x["instances"].to(device) for x in batched_inputs] for i in range(len(samples)): if batched_extras is not None: t = (batched_extras["token_labels"][i] == 0).sum().item() count_length_dict[t] += 1 room_polys = gt_instances[i].gt_masks.polygons room_ids = gt_instances[i].gt_classes.detach().cpu().numpy() count_room_dict[len(room_ids)] += 1 for poly, poly_id in zip(room_polys, room_ids): poly = poly[0].reshape(-1, 2).astype(np.int32) count_pts_dict[len(poly)] += 1 if len(poly) > max_num_points: max_num_points = len(poly) if len(room_ids) > max_num_polys: max_num_polys = len(room_ids) print(f"Max pts: {max_num_points}, Max polys: {max_num_polys}") plot_histogram( count_pts_dict, "Points in Polygons", os.path.join(output_dir, f"{eval_set}_polygon_histogram.png"), bin_size=5 ) plot_histogram( count_room_dict, "Rooms in Floorplan image", os.path.join(output_dir, f"{eval_set}_room_histogram.png"), bin_size=5, ) plot_histogram( count_length_dict, "Corners in Floorplan image", os.path.join(output_dir, f"{eval_set}_seqlen_histogram.png"), bin_size=30, ) def get_args_parser(): parser = argparse.ArgumentParser("Raster2Seq plotting script", add_help=False) parser.add_argument("--batch_size", default=10, type=int) parser.add_argument("--debug", action="store_true") parser.add_argument("--image_size", type=int, default=256) parser.add_argument("--wd_only", action="store_true") parser.add_argument("--drop_wd", action="store_true", help="disable Windor & Door in the plots") parser.add_argument( "--crop_white_space", action="store_true", help="remove redundant whitespace from the rendering" ) parser.add_argument("--image_scale", type=int, default=1, help="adjust rendering resolution of the plots") parser.add_argument("--one_color", action="store_true", help="use single color for every room (i.e. yellow)") parser.add_argument("--is_bw", action="store_true", help="plot floorplan as binary image") parser.add_argument("--plot_engine", type=str, default="opencv") parser.add_argument( "--compute_stats", action="store_true", help="compute statistics of the dataset (e.g. max_num_pts, max_num_polys) " "and plot histogram for counting number of Points, Rooms, Corners", ) # poly2seq parser.add_argument("--poly2seq", action="store_true") parser.add_argument("--seq_len", type=int, default=1024) parser.add_argument("--num_bins", type=int, default=64) parser.add_argument("--add_cls_token", action="store_true") parser.add_argument("--per_token_sem_loss", action="store_true") # backbone parser.add_argument("--input_channels", default=1, type=int) parser.add_argument("--backbone", default="resnet50", type=str, help="Name of the convolutional backbone to use") parser.add_argument("--lr_backbone", default=0, type=float) parser.add_argument( "--dilation", action="store_true", help="If true, we replace stride with dilation in the last convolutional block (DC5)", ) parser.add_argument( "--position_embedding", default="sine", type=str, choices=("sine", "learned"), help="Type of positional embedding to use on top of the image features", ) parser.add_argument("--position_embedding_scale", default=2 * np.pi, type=float, help="position / size * scale") parser.add_argument("--num_feature_levels", default=4, type=int, help="number of feature levels") parser.add_argument("--image_norm", action="store_true") parser.add_argument("--disable_image_transform", action="store_true") # Transformer parser.add_argument("--enc_layers", default=6, type=int, help="Number of encoding layers in the transformer") parser.add_argument("--dec_layers", default=6, type=int, help="Number of decoding layers in the transformer") parser.add_argument( "--dim_feedforward", default=1024, type=int, help="Intermediate size of the feedforward layers in the transformer blocks", ) parser.add_argument( "--hidden_dim", default=256, type=int, help="Size of the embeddings (dimension of the transformer)" ) parser.add_argument("--dropout", default=0.1, type=float, help="Dropout applied in the transformer") parser.add_argument( "--nheads", default=8, type=int, help="Number of attention heads inside the transformer's attentions" ) parser.add_argument( "--num_queries", default=800, type=int, help="Number of query slots (num_polys * max. number of corner per poly)", ) parser.add_argument("--num_polys", default=20, type=int, help="Number of maximum number of room polygons") parser.add_argument("--dec_n_points", default=4, type=int) parser.add_argument("--enc_n_points", default=4, type=int) parser.add_argument( "--query_pos_type", default="sine", type=str, choices=("static", "sine", "none"), help="Type of query pos in decoder - \ 1. static: same setting with DETR and Deformable-DETR, the query_pos is the same for all layers \ 2. sine: since embedding from reference points (so if references points update, query_pos also \ 3. none: remove query_pos", ) parser.add_argument( "--with_poly_refine", default=True, action="store_true", help="iteratively refine reference points (i.e. positional part of polygon queries)", ) parser.add_argument( "--masked_attn", default=False, action="store_true", help="if true, the query in one room will not be allowed to attend other room", ) parser.add_argument( "--semantic_classes", default=-1, type=int, help="Number of classes for semantically-rich floorplan: \ 1. default -1 means non-semantic floorplan \ 2. 19 for Structured3D: 16 room types + 1 door + 1 window + 1 empty", ) parser.add_argument( "--use_room_attn_at_last_dec_layer", default=False, action="store_true", help="use room-wise attention in last decoder layer", ) # aux parser.add_argument( "--no_aux_loss", dest="aux_loss", action="store_true", help="Disables auxiliary decoding losses (loss at each layer)", ) # dataset parameters parser.add_argument("--dataset_name", default="stru3d") parser.add_argument("--dataset_root", default="data/stru3d", type=str) parser.add_argument("--eval_set", default="test", type=str) parser.add_argument("--device", default="cuda", help="device to use for training / testing") parser.add_argument("--num_workers", default=2, type=int) parser.add_argument("--seed", default=42, type=int) parser.add_argument("--checkpoint", default="checkpoints/roomformer_scenecad.pth", help="resume from checkpoint") parser.add_argument("--output_dir", default="eval_stru3d", help="path where to save result") # visualization options parser.add_argument( "--plot_density", default=False, action="store_true", help="plot predicited room polygons overlaid on the density map", ) parser.add_argument("--plot_gt", default=False, action="store_true", help="plot ground truth floorplan") parser.add_argument("--plot_gt_image", default=False, action="store_true", help="plot ground truth image") return parser def main(args): device = "cpu" # torch.device(args.device) # fix the seed for reproducibility seed = args.seed torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) # build dataset and dataloader dataset_eval = build_dataset(image_set=args.eval_set, args=args) # for test if args.debug: dataset_eval = torch.utils.data.Subset(dataset_eval, [7]) # list(range(0, args.batch_size, 1)) sampler_eval = torch.utils.data.SequentialSampler(dataset_eval) def trivial_batch_collator(batch): """ A batch collator that does nothing. """ if "target_seq" in batch[0]: # Concatenate tensors for each key in the batch delta_x1 = torch.stack([item["delta_x1"] for item in batch], dim=0) delta_x2 = torch.stack([item["delta_x2"] for item in batch], dim=0) delta_y1 = torch.stack([item["delta_y1"] for item in batch], dim=0) delta_y2 = torch.stack([item["delta_y2"] for item in batch], dim=0) seq11 = torch.stack([item["seq11"] for item in batch], dim=0) seq21 = torch.stack([item["seq21"] for item in batch], dim=0) seq12 = torch.stack([item["seq12"] for item in batch], dim=0) seq22 = torch.stack([item["seq22"] for item in batch], dim=0) target_seq = torch.stack([item["target_seq"] for item in batch], dim=0) token_labels = torch.stack([item["token_labels"] for item in batch], dim=0) mask = torch.stack([item["mask"] for item in batch], dim=0) # Delete the keys from the batch for item in batch: del item["delta_x1"] del item["delta_x2"] del item["delta_y1"] del item["delta_y2"] del item["seq11"] del item["seq21"] del item["seq12"] del item["seq22"] del item["target_seq"] del item["token_labels"] del item["mask"] # Return the concatenated batch return batch, { "delta_x1": delta_x1, "delta_x2": delta_x2, "delta_y1": delta_y1, "delta_y2": delta_y2, "seq11": seq11, "seq21": seq21, "seq12": seq12, "seq22": seq22, "target_seq": target_seq, "token_labels": token_labels, "mask": mask, } return batch, None data_loader_eval = DataLoader( dataset_eval, args.batch_size, sampler=sampler_eval, drop_last=False, collate_fn=trivial_batch_collator, num_workers=args.num_workers, pin_memory=True, ) output_dir = Path(args.output_dir) save_dir = output_dir # os.path.join(os.path.dirname(args.checkpoint), output_dir) os.makedirs(save_dir, exist_ok=True) if args.plot_gt: plot_gt_floor( args, data_loader_eval, device, save_dir, plot_gt=args.plot_gt, semantic_rich=args.semantic_classes > 0, crop_white_space=args.crop_white_space, ) if args.plot_density: plot_polys(data_loader_eval, device, save_dir) if args.plot_gt_image: plot_gt_image(data_loader_eval, device, save_dir, crop_white_space=args.crop_white_space) if args.compute_stats: loop_data(data_loader_eval, args.eval_set, device, save_dir) if __name__ == "__main__": parser = argparse.ArgumentParser("Raster2Seq plotting script", parents=[get_args_parser()]) args = parser.parse_args() main(args)