Spaces:
Runtime error
Runtime error
| import argparse | |
| import copy | |
| import json | |
| import os | |
| import random | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm, trange | |
| from datasets.discrete_tokenizer import DiscreteTokenizer | |
| from datasets.transforms import ResizeAndPad | |
| from detectron2.data import transforms as T | |
| from engine import generate, plot_density_map | |
| from models import build_model | |
| from util.plot_utils import CC5K_LABEL, S3D_LABEL, auto_crop_whitespace, plot_semantic_rich_floorplan_opencv | |
| class ImageDataset(Dataset): | |
| def __init__(self, image_paths, num_image_channels=3, transform=None): | |
| """ | |
| Args: | |
| image_paths (list): List of image file paths. | |
| transform (callable, optional): Optional transform to be applied on an image. | |
| """ | |
| self.image_paths = image_paths | |
| self.transform = transform | |
| self.num_image_channels = num_image_channels | |
| def __len__(self): | |
| return len(self.image_paths) | |
| def _expand_image_dims(self, x): | |
| if len(x.shape) == 2: | |
| exp_img = np.expand_dims(x, 0) | |
| else: | |
| exp_img = x.transpose((2, 0, 1)) # (h,w,c) -> (c,h,w) | |
| return exp_img | |
| def __getitem__(self, idx): | |
| """ | |
| Args: | |
| idx (int): Index of the image to fetch. | |
| Returns: | |
| torch.Tensor: Transformed image tensor. | |
| """ | |
| img_path = self.image_paths[idx] | |
| if self.num_image_channels == 3: | |
| image = np.array(Image.open(img_path).convert("RGB")) # Ensure 3-channel RGB | |
| else: | |
| image = np.array(Image.open(img_path)) # Ensure 1-channel RGB | |
| if self.transform: | |
| aug_input = T.AugInput(image) | |
| _ = self.transform(aug_input) | |
| image = aug_input.image | |
| image = (1 / 255) * torch.as_tensor(np.array(self._expand_image_dims(image))) | |
| return { | |
| "file_name": img_path, | |
| "image": image, | |
| } | |
| def get_args_parser(): | |
| parser = argparse.ArgumentParser("Raster2Seq prediction script", add_help=False) | |
| parser.add_argument("--batch_size", default=10, type=int) | |
| parser.add_argument("--debug", action="store_true") | |
| parser.add_argument("--input_channels", default=1, type=int) | |
| parser.add_argument("--image_norm", action="store_true") | |
| parser.add_argument("--eval_every_epoch", type=int, default=20) | |
| parser.add_argument("--ckpt_every_epoch", type=int, default=20) | |
| parser.add_argument("--label_smoothing", type=float, default=0.0) | |
| parser.add_argument("--ignore_index", type=int, default=-1) | |
| parser.add_argument("--image_size", type=int, default=256) | |
| parser.add_argument("--ema4eval", action="store_true") | |
| parser.add_argument("--measure_time", action="store_true") | |
| parser.add_argument("--disable_sampling_cache", action="store_true") | |
| parser.add_argument("--use_anchor", action="store_true") | |
| parser.add_argument("--drop_wd", action="store_true") | |
| parser.add_argument("--plot_text", action="store_true") | |
| parser.add_argument("--image_scale", type=int, default=2) | |
| parser.add_argument("--one_color", action="store_true") | |
| parser.add_argument("--crop_white_space", action="store_true") | |
| # raster2seq | |
| parser.add_argument("--poly2seq", action="store_true") | |
| parser.add_argument("--seq_len", type=int, default=1024) | |
| parser.add_argument("--num_bins", type=int, default=64) | |
| parser.add_argument("--pre_decoder_pos_embed", action="store_true") | |
| parser.add_argument("--learnable_dec_pe", action="store_true") | |
| parser.add_argument("--dec_qkv_proj", action="store_true") | |
| parser.add_argument("--dec_attn_concat_src", action="store_true") | |
| parser.add_argument("--per_token_sem_loss", action="store_true") | |
| parser.add_argument("--add_cls_token", action="store_true") | |
| # backbone | |
| parser.add_argument("--backbone", default="resnet50", type=str, help="Name of the convolutional backbone to use") | |
| parser.add_argument("--lr_backbone", default=0, type=float) | |
| parser.add_argument( | |
| "--dilation", | |
| action="store_true", | |
| help="If true, we replace stride with dilation in the last convolutional block (DC5)", | |
| ) | |
| parser.add_argument( | |
| "--position_embedding", | |
| default="sine", | |
| type=str, | |
| choices=("sine", "learned"), | |
| help="Type of positional embedding to use on top of the image features", | |
| ) | |
| parser.add_argument("--position_embedding_scale", default=2 * np.pi, type=float, help="position / size * scale") | |
| parser.add_argument("--num_feature_levels", default=4, type=int, help="number of feature levels") | |
| # Transformer | |
| parser.add_argument("--enc_layers", default=6, type=int, help="Number of encoding layers in the transformer") | |
| parser.add_argument("--dec_layers", default=6, type=int, help="Number of decoding layers in the transformer") | |
| parser.add_argument( | |
| "--dim_feedforward", | |
| default=1024, | |
| type=int, | |
| help="Intermediate size of the feedforward layers in the transformer blocks", | |
| ) | |
| parser.add_argument( | |
| "--hidden_dim", default=256, type=int, help="Size of the embeddings (dimension of the transformer)" | |
| ) | |
| parser.add_argument("--dropout", default=0.1, type=float, help="Dropout applied in the transformer") | |
| parser.add_argument( | |
| "--nheads", default=8, type=int, help="Number of attention heads inside the transformer's attentions" | |
| ) | |
| parser.add_argument( | |
| "--num_queries", | |
| default=800, | |
| type=int, | |
| help="Number of query slots (num_polys * max. number of corner per poly)", | |
| ) | |
| parser.add_argument("--num_polys", default=20, type=int, help="Number of maximum number of room polygons") | |
| parser.add_argument("--dec_n_points", default=4, type=int) | |
| parser.add_argument("--enc_n_points", default=4, type=int) | |
| parser.add_argument( | |
| "--query_pos_type", | |
| default="sine", | |
| type=str, | |
| choices=("static", "sine", "none"), | |
| help="Type of query pos in decoder - \ | |
| 1. static: same setting with DETR and Deformable-DETR, the query_pos is the same for all layers \ | |
| 2. sine: since embedding from reference points (so if references points update, query_pos also \ | |
| 3. none: remove query_pos", | |
| ) | |
| parser.add_argument( | |
| "--with_poly_refine", | |
| default=True, | |
| action="store_true", | |
| help="iteratively refine reference points (i.e. positional part of polygon queries)", | |
| ) | |
| parser.add_argument( | |
| "--masked_attn", | |
| default=False, | |
| action="store_true", | |
| help="if true, the query in one room will not be allowed to attend other room", | |
| ) | |
| parser.add_argument( | |
| "--semantic_classes", | |
| default=-1, | |
| type=int, | |
| help="Number of classes for semantically-rich floorplan: \ | |
| 1. default -1 means non-semantic floorplan \ | |
| 2. 19 for Structured3D: 16 room types + 1 door + 1 window + 1 empty", | |
| ) | |
| parser.add_argument( | |
| "--disable_poly_refine", | |
| action="store_true", | |
| help="iteratively refine reference points (i.e. positional part of polygon queries)", | |
| ) | |
| # aux | |
| parser.add_argument( | |
| "--no_aux_loss", | |
| dest="aux_loss", | |
| action="store_true", | |
| help="Disables auxiliary decoding losses (loss at each layer)", | |
| ) | |
| # dataset parameters | |
| parser.add_argument("--dataset_name", default="stru3d") | |
| parser.add_argument("--dataset_root", default="data/stru3d", type=str) | |
| parser.add_argument("--eval_set", default="test", type=str) | |
| parser.add_argument("--device", default="cuda", help="device to use for training / testing") | |
| parser.add_argument("--num_workers", default=2, type=int) | |
| parser.add_argument("--seed", default=42, type=int) | |
| parser.add_argument("--checkpoint", default="checkpoints/roomformer_scenecad.pth", help="resume from checkpoint") | |
| parser.add_argument("--output_dir", default="eval_stru3d", help="path where to save result") | |
| # visualization options | |
| parser.add_argument("--plot_pred", default=True, type=bool, help="plot predicted floorplan") | |
| parser.add_argument( | |
| "--plot_density", default=True, type=bool, help="plot predicited room polygons overlaid on the density map" | |
| ) | |
| parser.add_argument("--plot_gt", default=False, type=bool, help="plot ground truth floorplan") | |
| parser.add_argument("--save_pred", action="store_true", help="save_pred") | |
| return parser | |
| def get_image_paths_from_directory(directory_path): | |
| """ | |
| Load all images from the specified directory. | |
| Args: | |
| directory_path (str): Path to the directory containing images. | |
| Returns: | |
| list: A list of PIL Image objects. | |
| """ | |
| paths = [] | |
| valid_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".tiff") # Add more extensions if needed | |
| # Iterate through all files in the directory | |
| for root, _, files in os.walk(directory_path): | |
| for filename in files: | |
| if filename.lower().endswith(valid_extensions): # Check for valid image extensions | |
| file_path = os.path.join(root, filename) | |
| paths.append(file_path) | |
| return paths | |
| def main(args): | |
| device = torch.device(args.device) | |
| # fix the seed for reproducibility | |
| seed = args.seed | |
| torch.manual_seed(seed) | |
| np.random.seed(seed) | |
| random.seed(seed) | |
| image_paths = get_image_paths_from_directory(args.dataset_root) | |
| data_transform = T.AugmentationList( | |
| [ | |
| ResizeAndPad((args.image_size, args.image_size), pad_value=255), | |
| ] | |
| ) | |
| dataset_eval = ImageDataset(image_paths, num_image_channels=args.input_channels, transform=data_transform) | |
| tokenizer = None | |
| if args.poly2seq: | |
| tokenizer = DiscreteTokenizer(args.num_bins, args.seq_len, add_cls=args.add_cls_token) | |
| args.vocab_size = len(tokenizer) | |
| # overfit one sample | |
| if args.debug: | |
| idx = 0 | |
| for i, x in enumerate(dataset_eval): | |
| if "3252" in x["file_name"]: | |
| idx = i | |
| dataset_eval = torch.utils.data.Subset(dataset_eval, [idx]) | |
| sampler_eval = torch.utils.data.SequentialSampler(dataset_eval) | |
| data_loader_eval = DataLoader( | |
| dataset_eval, | |
| args.batch_size, | |
| sampler=sampler_eval, | |
| drop_last=False, | |
| num_workers=args.num_workers, | |
| pin_memory=True, | |
| ) | |
| # build model | |
| model = build_model(args, train=False, tokenizer=tokenizer) | |
| model.to(device) | |
| n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| print("number of params:", n_parameters) | |
| checkpoint = torch.load(args.checkpoint, map_location="cpu") | |
| if args.ema4eval: | |
| ckpt_state_dict = copy.deepcopy(checkpoint["ema"]) | |
| else: | |
| ckpt_state_dict = copy.deepcopy(checkpoint["model"]) | |
| for key, value in checkpoint["model"].items(): | |
| if key.startswith("module."): | |
| ckpt_state_dict[key[7:]] = checkpoint["model"][key] | |
| del ckpt_state_dict[key] | |
| missing_keys, unexpected_keys = model.load_state_dict(ckpt_state_dict, strict=False) | |
| unexpected_keys = [k for k in unexpected_keys if not (k.endswith("total_params") or k.endswith("total_ops"))] | |
| if len(missing_keys) > 0: | |
| print("Missing Keys: {}".format(missing_keys)) | |
| if len(unexpected_keys) > 0: | |
| print("Unexpected Keys: {}".format(unexpected_keys)) | |
| # disable grad | |
| for param in model.parameters(): | |
| param.requires_grad = False | |
| save_dir = os.path.join(args.output_dir, os.path.dirname(args.checkpoint).split("/")[-1]) | |
| os.makedirs(save_dir, exist_ok=True) | |
| semantics_label_mapping = None | |
| if args.dataset_name == "stru3d": | |
| door_window_index = [16, 17] | |
| semantics_label_mapping = S3D_LABEL | |
| elif args.dataset_name == "cubicasa": | |
| door_window_index = [10, 9] | |
| semantics_label_mapping = CC5K_LABEL | |
| elif args.dataset_name == "waffle": | |
| door_window_index = [1, 2] | |
| else: | |
| door_window_index = [] | |
| if args.measure_time: | |
| images = torch.rand(args.batch_size, 3, args.image_size, args.image_size).to(device) | |
| if args.poly2seq: | |
| model = torch.compile(model) # compile model is not compatible with RoomFormer | |
| # GPU-WARM-UP | |
| for _ in trange(10, desc="GPU-WARM-UP"): | |
| if not args.poly2seq: | |
| _ = model(images) | |
| else: | |
| _ = model.forward_inference(images) | |
| # INIT LOGGERS | |
| starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | |
| total_time = 0.0 | |
| for batch_images in tqdm(data_loader_eval): | |
| starter.record() | |
| x = batch_images["image"].to(device) | |
| filenames = batch_images["file_name"] | |
| outputs = generate( | |
| model, | |
| x, | |
| semantic_rich=args.semantic_classes > 0, | |
| use_cache=True, | |
| per_token_sem_loss=args.per_token_sem_loss, | |
| drop_wd=args.drop_wd, | |
| poly2seq=args.poly2seq, | |
| ) | |
| ender.record() | |
| torch.cuda.synchronize() | |
| total_time += starter.elapsed_time(ender) / len(data_loader_eval) | |
| pred_rooms = outputs["room"] | |
| pred_labels = outputs["labels"] | |
| image_size = x.shape[-2] | |
| for j, (pred_rm, pred_cls) in enumerate(zip(pred_rooms, pred_labels)): | |
| if pred_cls is None: | |
| pred_cls = [-1] * len(pred_rm) | |
| fn = os.path.basename(filenames[j]).split(".")[0] | |
| pred_room_map = plot_density_map( | |
| x[j], | |
| image_size, | |
| pred_rm, | |
| pred_cls, | |
| plot_text=args.plot_text, | |
| ) | |
| floorplan_map = plot_semantic_rich_floorplan_opencv( | |
| zip(pred_rm, pred_cls), | |
| None, | |
| door_window_index=door_window_index, | |
| semantics_label_mapping=semantics_label_mapping, | |
| plot_text=args.plot_text, | |
| one_color=args.one_color, | |
| is_sem=args.semantic_classes > 0, | |
| img_w=image_size * args.image_scale, | |
| img_h=image_size * args.image_scale, | |
| scale=args.image_scale, | |
| ) | |
| image = x[j].permute(1, 2, 0).cpu().numpy() * 255 | |
| if args.crop_white_space: | |
| image = cv2.resize( | |
| image, | |
| (args.image_scale * args.image_size, args.image_scale * args.image_size), | |
| interpolation=cv2.INTER_NEAREST, | |
| ) | |
| image, cropped_box = auto_crop_whitespace(image) | |
| _x, _y, _w, _h = [ele for ele in cropped_box] | |
| floorplan_map = floorplan_map[_y : _y + _h, _x : _x + _w].copy() | |
| # Ensure images are not empty before saving | |
| if pred_room_map is not None and pred_room_map.size > 0: | |
| cv2.imwrite(os.path.join(save_dir, "{}_pred_room_map.png".format(fn)), pred_room_map) | |
| else: | |
| print("Warning: pred_room_map is empty, skipping save.") | |
| if floorplan_map is not None and floorplan_map.size > 0: | |
| cv2.imwrite(os.path.join(save_dir, "{}_pred_floorplan.png".format(fn)), floorplan_map) | |
| else: | |
| print("Warning: floorplan_map is empty, skipping save.") | |
| if image is not None and image.size > 0: | |
| cv2.imwrite(os.path.join(save_dir, "{}.png".format(fn)), image) | |
| else: | |
| print("Warning: image is empty, skipping save.") | |
| if args.save_pred: | |
| # Save room_polys as JSON | |
| json_path = os.path.join(save_dir, "jsons", "{}.json".format(fn)) | |
| npy_path = os.path.join(save_dir, "npy", "{}.npy".format(fn)) | |
| os.makedirs(os.path.dirname(json_path), exist_ok=True) | |
| os.makedirs(os.path.dirname(npy_path), exist_ok=True) | |
| polys_list = [poly.astype(float).tolist() for poly in pred_rm] | |
| types_list = pred_cls | |
| output_json = [ | |
| { | |
| "image_id": fn, | |
| "segmentation": polys_list[instance_id], | |
| "category_id": int(types_list[instance_id]), | |
| "id": instance_id, | |
| } | |
| for instance_id in range(len(polys_list)) | |
| ] | |
| with open(json_path, "w") as json_file: | |
| json.dump(output_json, json_file) | |
| polys_list = [np.array(poly).reshape(-1, 2) for poly in polys_list] | |
| np.save(npy_path, np.array(polys_list, dtype=object)) | |
| print(f"Total inference time: {total_time:.2f} ms") | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser("Raster2Seq prediction script", parents=[get_args_parser()]) | |
| args = parser.parse_args() | |
| if args.debug: | |
| args.batch_size = 1 | |
| if args.disable_poly_refine: | |
| args.with_poly_refine = False | |
| main(args) | |