# DeepFloorPlan / postprocess.py
# Uploaded by rawanessam ("Upload 25 files"), commit 32fa779 (verified).
import argparse
import os
import sys
import glob
import time
import random
import numpy as np
import imageio
from PIL import Image
from matplotlib import pyplot as plt
sys.path.append('./utils/')
from rgb_ind_convertor import *
from util import *
# Command-line interface: a single option pointing at the directory that
# holds the network's prediction images to be post-processed.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--result_dir',
    type=str,
    default='./out',
    help='The folder that save network predictions.',
)
def post_process(input_dir, save_dir, merge=True):
    """Refine room-segmentation predictions and save them as RGB images.

    Every ``*.png`` directly inside *input_dir* (no recursion) is read as RGB
    and converted to a label map with ``rgb2ind(..., color_map=
    floorplan_fuse_map)``. The map is split into a room map (labels 9 and 10
    zeroed) and a boundary map (only labels 9 and 10 kept). The room map is
    refined with ``fill_break_line``, ``flood_fill`` and
    ``refine_room_region``, then written to *save_dir* under the same file
    name.

    Args:
        input_dir: Directory containing the prediction ``*.png`` files.
        save_dir: Output directory. It is created, including any missing
            parents, if it does not already exist.
        merge: If True, paint boundary labels 9 and 10 back onto the refined
            room map and colour it with ``floorplan_fuse_map``. Otherwise
            colour the room map alone with ``ind2rgb``'s default colour map.

    Returns:
        None. Results are written to disk, and one progress line is printed
        per image.
    """
    # makedirs copes with missing parents and an already-existing directory;
    # os.mkdir raised in both cases.
    os.makedirs(save_dir, exist_ok=True)

    # Labels split out of the room map as "boundary". Their meaning is defined
    # by floorplan_fuse_map (not visible here).
    boundary_labels = (9, 10)

    input_paths = sorted(glob.glob(os.path.join(input_dir, '*.png')))
    for i, in_path in enumerate(input_paths):
        # basename is separator-aware. Splitting on '/' returned the whole path
        # on Windows, and joining save_dir with that absolute path pointed back
        # at the input file, which was then overwritten.
        out_path = os.path.join(save_dir, os.path.basename(in_path))

        im = imageio.imread(in_path, mode='RGB')
        im_ind = rgb2ind(im, color_map=floorplan_fuse_map)

        # Separate the label map into room-seg and boundary-seg.
        is_boundary = np.isin(im_ind, boundary_labels)
        rm_ind = im_ind.copy()
        rm_ind[is_boundary] = 0
        bd_ind = np.zeros(im_ind.shape, dtype=np.uint8)
        bd_ind[is_boundary] = im_ind[is_boundary]

        # Region covered by the room prediction itself (0.0 / 1.0 float mask).
        rm_mask = (rm_ind > 0).astype(np.float64)
        # Close-wall (boundary) pixels, with fill_break_line used to fill the
        # gaps between boundary lines.
        cw_mask = fill_break_line((bd_ind > 0).astype(np.uint8))

        # Union of both masks: covered pixels are set to 255, flood_fill is
        # used to fill holes, and the result is scaled back down by 255.
        fuse_mask = cw_mask + rm_mask
        fuse_mask[fuse_mask >= 1] = 255
        fuse_mask = flood_fill(fuse_mask)
        fuse_mask = fuse_mask // 255

        # One room, one label.
        new_rm_ind = refine_room_region(cw_mask, rm_ind)
        # Drop room labels outside the fused region (background mislabeling).
        new_rm_ind = fuse_mask * new_rm_ind

        print('Saving {}th refined room prediciton to {}'.format(i, out_path))
        if merge:
            # Paint the boundary labels back on top of the refined rooms.
            for label in boundary_labels:
                new_rm_ind[bd_ind == label] = label
            rgb = ind2rgb(new_rm_ind, color_map=floorplan_fuse_map)
        else:
            rgb = ind2rgb(new_rm_ind)
        imageio.imwrite(out_path, rgb)
if __name__ == '__main__':
    # parse_known_args tolerates (and ignores) unrecognised CLI arguments.
    args, _unknown = parser.parse_known_args()
    result_dir = args.result_dir
    # Refined predictions are written to a 'post' subfolder of the inputs.
    post_process(result_dir, os.path.join(result_dir, 'post'))