Spaces:
Runtime error
Runtime error
| import matplotlib.path as mplp | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib import cm, colors | |
| from shapely.geometry import Point, Polygon | |
| from skimage import draw | |
def discrete_cmap_furukawa():
    """Create and register the ``"rooms_furukawa"`` and ``"icons_furukawa"`` colormaps.

    ``"rooms_furukawa"`` holds 13 discrete colours and ``"icons_furukawa"``
    holds 17. Once registered they can be referenced by name, e.g.
    ``imshow(..., cmap="rooms_furukawa")``.

    Safe to call repeatedly. On matplotlib versions that provide the
    ``matplotlib.colormaps`` registry, a name that is already registered is
    skipped, because re-registering it would raise ``ValueError``.
    """
    # Local import: only needed to reach the colormap registry.
    import matplotlib

    palettes = {
        "rooms_furukawa": [
            "#696969",
            "#b3de69",
            "#ffffb3",
            "#8dd3c7",
            "#fdb462",
            "#fccde5",
            "#80b1d3",
            "#d9d9d9",
            "#fb8072",
            "#577a4d",
            "white",
            "#000000",
            "#e31a1c",
        ],
        "icons_furukawa": [
            "#ede676",
            "#8dd3c7",
            "#b15928",
            "#fdb462",
            "#ffff99",
            "#fccde5",
            "#80b1d3",
            "#d9d9d9",
            "#fb8072",
            "#696969",
            "#577a4d",
            "#e31a1c",
            "#42ef59",
            "#8c595a",
            "#3131e5",
            "#48e0e6",
            "white",
        ],
    }

    registry = getattr(matplotlib, "colormaps", None)
    use_registry = registry is not None and hasattr(registry, "register")
    for name, cpool in palettes.items():
        cmap = colors.ListedColormap(cpool, name)
        if not use_registry:
            # Older matplotlib: only the legacy registration API exists.
            cm.register_cmap(cmap=cmap)
        elif name not in registry:
            # cm.register_cmap is deprecated since 3.6 and removed in 3.9.
            registry.register(cmap)
def drawJunction(h, point, point_type, width, height):
    """Draw one junction or opening marker onto the plotting target ``h``.

    ``h`` must provide matplotlib-style ``plot``, ``scatter`` and ``text``
    methods (presumably a matplotlib ``Axes`` -- confirm with callers).
    Arm end points are clamped to ``[0, width - 1]`` horizontally and
    ``[0, height - 1]`` vertically. "Up" below means towards smaller ``y``.

    What is drawn depends on ``point_type``:

    * ``-1``: a single scatter dot.
    * ``0``-``12``: a wall junction, 1-4 arms of length 15 and line width 10.
    * ``13``-``16``: an opening end point, a red ring labelled OL/OR/OU/OD.
    * ``17``-``20``: an opening corner, 2 arms of length 10 and line width 5.
    * Any other value: nothing.

    Parameters
    ----------
    h : object with ``plot`` / ``scatter`` / ``text`` methods
    point : (x, y) pair
    point_type : int
    width, height : image size used for clamping the arms
    """
    x, y = point

    if point_type == -1:
        h.scatter(x, y, color="#6488ea")

    def arm(direction, length):
        # (xs, ys) of one arm starting at (x, y), clamped to the image extent.
        if direction == "up":
            return [x, x], [y, max(y - length, 0)]
        if direction == "down":
            return [x, x], [y, min(y + length, height - 1)]
        if direction == "left":
            return [x, max(x - length, 0)], [y, y]
        return [x, min(x + length, width - 1)], [y, y]  # "right"

    def lookup(table):
        # Match with == (like an if/elif chain) rather than by hashing.
        return next((entry for code, entry in table.items() if point_type == code), None)

    # Wall junctions: type -> (arms in drawing order, colour).
    wall_junctions = {
        0: (("down",), "#6488ea"),
        1: (("left",), "#6241c7"),
        2: (("up",), "#056eee"),
        3: (("right",), "#004577"),
        4: (("left", "up"), "#ff81c0"),
        5: (("right", "up"), "#f97306"),
        6: (("right", "down"), "#04d8b2"),
        7: (("left", "down"), "#cdfd02"),
        8: (("right", "left", "down"), "y"),
        9: (("left", "up", "down"), "r"),
        10: (("right", "left", "up"), "m"),
        11: (("right", "up", "down"), "b"),
        12: (("right", "left", "up", "down"), "k"),
    }
    # Opening corners: type -> (arms in drawing order, colour).
    opening_corners = {
        17: (("right", "down"), "indianred"),
        18: (("left", "down"), "darkred"),
        19: (("right", "up"), "salmon"),
        20: (("left", "up"), "orangered"),
    }
    # Opening end points: type -> (label, label colour).
    opening_ends = {
        13: ("OL", "magenta"),
        14: ("OR", "magenta"),
        15: ("OU", "mediumblue"),
        16: ("OD", "mediumblue"),
    }

    style = lookup(wall_junctions)
    length, line_width = 15, 10
    if style is None:
        style = lookup(opening_corners)
        length, line_width = 10, 5
    if style is not None:
        arms, color = style
        for direction in arms:
            xs, ys = arm(direction, length)
            h.plot(xs, ys, linewidth=line_width, color=color)

    end = lookup(opening_ends)
    if end is not None:
        label, label_color = end
        h.plot([x], [y], "o", markersize=30, color="red")
        h.plot([x], [y], "o", markersize=25, color="white")
        h.text(x, y, label, fontsize=30, color=label_color)
def draw_junction_from_dict(point_dict, width, height, size=1, fontsize=30):
    """Draw every junction / opening in ``point_dict`` on the current pyplot axes.

    Parameters
    ----------
    point_dict : mapping
        Maps ``point_type`` to an iterable of ``(x, y)`` locations. Types
        0-12 are wall junctions, 13-16 opening end points (red ring labelled
        OL/OR/OU/OD), and 17-20 opening corners. Any other type draws nothing.
    width, height : int
        Image size. Arm end points are clamped to ``[0, width - 1]`` and
        ``[0, height - 1]``. "Up" below means towards smaller ``y``.
    size : float, default 1
        Scale factor. Wall-junction arms are ``20 * size`` long and wide,
        corner arms ``15 * size``, and the marker rings use markersizes
        ``20 * size`` and ``15 * size``.
    fontsize : float, default 30
        Font size of the opening labels.
    """
    # Size-dependent constants do not vary per point, so compute them once.
    junction_length = junction_width = 20 * size
    corner_length = corner_width = 15 * size
    markersize_large = 20 * size
    markersize_small = 15 * size

    # Wall junctions: type -> (arms in drawing order, colour).
    wall_junctions = {
        0: (("down",), "#6488ea"),
        1: (("left",), "#6241c7"),
        2: (("up",), "#056eee"),
        3: (("right",), "#004577"),
        4: (("left", "up"), "#ff81c0"),
        5: (("right", "up"), "#f97306"),
        6: (("right", "down"), "#04d8b2"),
        7: (("left", "down"), "#cdfd02"),
        8: (("right", "left", "down"), "y"),
        9: (("left", "up", "down"), "r"),
        10: (("right", "left", "up"), "m"),
        11: (("right", "up", "down"), "b"),
        12: (("right", "left", "up", "down"), "k"),
    }
    # Opening corners: type -> (arms in drawing order, colour).
    opening_corners = {
        17: (("right", "down"), "indianred"),
        18: (("left", "down"), "darkred"),
        19: (("right", "up"), "salmon"),
        20: (("left", "up"), "orangered"),
    }
    # Opening end points: type -> (label, label colour).
    opening_ends = {
        13: ("OL", "magenta"),
        14: ("OR", "magenta"),
        15: ("OU", "mediumblue"),
        16: ("OD", "mediumblue"),
    }

    def lookup(table, point_type):
        # Match with == (like an if/elif chain) rather than by hashing.
        return next((entry for code, entry in table.items() if point_type == code), None)

    def arm(x, y, direction, length):
        # (xs, ys) of one arm starting at (x, y), clamped to the image extent.
        if direction == "up":
            return [x, x], [y, max(y - length, 0)]
        if direction == "down":
            return [x, x], [y, min(y + length, height - 1)]
        if direction == "left":
            return [x, max(x - length, 0)], [y, y]
        return [x, min(x + length, width - 1)], [y, y]  # "right"

    for point_type, locations in point_dict.items():
        # Resolve the drawing style once per type rather than once per location.
        arms, color, length, line_width = (), None, None, None
        style = lookup(wall_junctions, point_type)
        if style is not None:
            arms, color = style
            length, line_width = junction_length, junction_width
        else:
            style = lookup(opening_corners, point_type)
            if style is not None:
                arms, color = style
                length, line_width = corner_length, corner_width
        end = lookup(opening_ends, point_type)

        for loc in locations:
            x, y = loc
            for direction in arms:
                xs, ys = arm(x, y, direction, length)
                plt.plot(xs, ys, linewidth=line_width, color=color)
            if end is not None:
                label, label_color = end
                plt.plot([x], [y], "o", markersize=markersize_large, color="red")
                plt.plot([x], [y], "o", markersize=markersize_small, color="white")
                plt.text(x, y, label, fontsize=fontsize, color=label_color)
def plot_pre_rec_4(instances, classes):
    """Plot precision/recall step curves for walls, openings, rooms and icons.

    Draws four subplots in a 2x2 grid on the current pyplot figure.

    Parameters
    ----------
    instances : array-like
        Indexed as ``instances[:, class_indices]`` and summed over axis 1.
        The last axis holds ``(correct, false negatives, false positives)``
        counts. Presumably shape ``(n_thresholds, n_classes, 3)``; TODO
        confirm against the caller.
    classes : list of str
        Class names. ``classes.index(name)`` gives the column in
        ``instances``, so every name listed below must be present
        (otherwise ``ValueError``).
    """
    walls = ["Wall", "Railing"]
    openings = ["Window", "Door"]
    # NOTE(review): "Sauna", "Fire Place", "Bathtub" and "Chimney" are listed
    # in both ``rooms`` and ``icons``; confirm which group they belong to.
    rooms = [
        "Outdoor",
        "Kitchen",
        "Living Room",
        "Bed Room",
        "Entry",
        "Dining",
        "Storage",
        "Garage",
        "Undefined Room",
        "Sauna",
        "Fire Place",
        "Bathtub",
        "Chimney",
    ]
    icons = [
        "Bath",
        "Closet",
        "Electrical Appliance",
        "Toilet",
        "Shower",
        "Sink",
        "Sauna",
        "Fire Place",
        "Bathtub",
        "Chimney",
    ]

    def make_sub_plot(classes_to_plot):
        plt.ylim([0.0, 1.0])
        plt.xlim([0.0, 1.0])
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        indx = [classes.index(i) for i in classes_to_plot]
        ins = instances[:, indx].sum(axis=1)
        correct = ins[:, 0]
        false_negatives = ins[:, 1]
        false_positive = ins[:, 2]
        # A row with nothing detected (or nothing to detect) gives 0/0 = NaN,
        # which matplotlib simply leaves as a gap; do not warn about it.
        with np.errstate(divide="ignore", invalid="ignore"):
            precision = correct / (correct + false_positive)
            recall = correct / (correct + false_negatives)
        # Reverse both arrays together so each recall value stays paired with
        # the precision from the same row (previously only recall was reversed).
        recall = recall[::-1]
        precision = precision[::-1]
        plt.step(recall, precision, color="b", alpha=0.2, where="post")
        plt.fill_between(recall, precision, step="post", alpha=0.2, color="b")

    panels = [("Walls", walls), ("Openings", openings), ("Rooms", rooms), ("Icons", icons)]
    for position, (title, group) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.title(title)
        make_sub_plot(group)
def discrete_cmap():
    """Create and register the discrete colormaps used by the plotting helpers.

    Registers four ``ListedColormap`` objects with matplotlib:

    * ``"rooms"``: 12 colours.
    * ``"icons"``: 11 colours.
    * ``"rooms_furu"``: 13 colours.
    * ``"icons_furu"``: 11 colours.

    Safe to call repeatedly. On matplotlib versions that provide the
    ``matplotlib.colormaps`` registry, a name that is already registered is
    skipped, because re-registering it would raise ``ValueError``.
    """
    # Local import: only needed to reach the colormap registry.
    import matplotlib

    palettes = {
        "rooms": [
            "#DCDCDC",
            "#b3de69",
            "#000000",
            "#8dd3c7",
            "#fdb462",
            "#fccde5",
            "#80b1d3",
            "#808080",
            "#fb8072",
            "#696969",
            "#577a4d",
            "#ffffb3",
        ],
        "icons": [
            "#DCDCDC",
            "#8dd3c7",
            "#b15928",
            "#fdb462",
            "#ffff99",
            "#fccde5",
            "#80b1d3",
            "#808080",
            "#fb8072",
            "#696969",
            "#577a4d",
        ],
        "rooms_furu": [
            "#DCDCDC",
            "#b3de69",
            "#000000",
            "#8dd3c7",
            "#fdb462",
            "#fccde5",
            "#80b1d3",
            "#808080",
            "#fb8072",
            "#696969",
            "#577a4d",
            "#ffffb3",
            "#d3d5d7",  # was "d3d5d7": without "#" it is not a valid colour spec
        ],
        # Icon palette. It was previously registered under "rooms_furu" a
        # second time, which clobbered the room palette above, or raised on
        # newer matplotlib.
        "icons_furu": [
            "#DCDCDC",
            "#8dd3c7",
            "#b15928",
            "#fdb462",
            "#ffff99",
            "#fccde5",
            "#80b1d3",
            "#808080",
            "#fb8072",
            "#696969",
            "#577a4d",
        ],
    }

    registry = getattr(matplotlib, "colormaps", None)
    use_registry = registry is not None and hasattr(registry, "register")
    for name, cpool in palettes.items():
        cmap = colors.ListedColormap(cpool, name)
        if not use_registry:
            # Older matplotlib: only the legacy registration API exists.
            cm.register_cmap(cmap=cmap)
        elif name not in registry:
            # cm.register_cmap is deprecated since 3.6 and removed in 3.9.
            registry.register(cmap)
def segmentation_plot(rooms_pred, icons_pred, rooms_label, icons_label):
    """Show ground truth next to prediction for room and icon segmentations.

    Displays two figures with ``plt.show()``. The first shows rooms with the
    ``"rooms"`` colormap, the second icons with the ``"icons"`` colormap.
    Each figure has a colour bar labelled with the class names.

    Parameters
    ----------
    rooms_pred, rooms_label : 2-D array-like of room class indices
    icons_pred, icons_label : 2-D array-like of icon class indices
    """
    room_classes = [
        "Background",
        "Outdoor",
        "Wall",
        "Kitchen",
        "Living Room",
        "Bed Room",
        "Bath",
        "Entry",
        "Railing",
        "Storage",
        "Garage",
        "Undefined",
    ]
    icon_classes = [
        "No Icon",
        "Window",
        "Door",
        "Closet",
        "Electrical Applience",
        "Toilet",
        "Sink",
        "Sauna Bench",
        "Fire Place",
        "Bathtub",
        "Chimney",
    ]
    discrete_cmap()  # custom colormap

    def show_pair(title, label, pred, cmap, class_names):
        n = len(class_names)
        # Class k is mapped to [k - 0.5, k + 0.5], so each colour band is
        # centred on integer tick k and every tick lies inside the colour bar.
        # When the colormap has n colours, each class gets the same colour as
        # with vmin=0, vmax=n-1.
        scale = {"cmap": cmap, "vmin": -0.5, "vmax": n - 0.5}
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(30, 15))
        axes[0].set_title(f"{title} Ground Truth")
        axes[0].imshow(label, **scale)
        axes[1].set_title(f"{title} Prediction")
        im = axes[1].imshow(pred, **scale)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax, ticks=np.arange(n))
        fig.subplots_adjust(right=0.8)
        cbar.ax.set_yticklabels(class_names)
        plt.show()

    show_pair("Room", rooms_label, rooms_pred, "rooms", room_classes)
    show_pair("Icon", icons_label, icons_pred, "icons", icon_classes)
def polygons_to_image(polygons, types, room_polygons, room_types, height, width):
    """Rasterise polygons into a room map and an icon map.

    Parameters
    ----------
    polygons : sequence of 2-D arrays
        Vertex arrays indexed as ``pol[:, 0]`` (x / column) and
        ``pol[:, 1]`` (y / row).
    types : sequence of dict
        ``types[i]["type"] == "wall"`` sends polygon ``i`` to the room map;
        any other type sends it to the icon map. ``types[i]["class"]`` is
        the value written.
    room_polygons : sequence
        Shapes passed to ``shp_mask``, so they need ``intersects`` and
        ``contains`` methods (e.g. shapely polygons).
    room_types : sequence of dict
        ``room_types[i]["class"]`` is the value written for room ``i``.
    height, width : int
        Output size.

    Returns
    -------
    pol_room_seg, pol_icon_seg : float arrays of shape (height, width)
        Rooms are filled first, then walls and icons. Later polygons
        overwrite earlier ones.
    """
    pol_room_seg = np.zeros((height, width))
    pol_icon_seg = np.zeros((height, width))
    for i, pol in enumerate(room_polygons):
        mask = shp_mask(pol, np.arange(width), np.arange(height))
        pol_room_seg[mask] = room_types[i]["class"]

    image_shape = (height, width)
    for i, pol in enumerate(polygons):
        # shape= clips the rasterised pixels to the image. Without it, a
        # polygon reaching x >= width or y >= height produces out-of-range
        # indices (IndexError).
        rr, cc = draw.polygon(pol[:, 1], pol[:, 0], shape=image_shape)
        target = pol_room_seg if types[i]["type"] == "wall" else pol_icon_seg
        target[rr, cc] = types[i]["class"]
    return pol_room_seg, pol_icon_seg
def plot_room(r, name, n_classes=12):
    """Render a room segmentation with the ``"rooms"`` colormap, save it, then show it.

    Parameters
    ----------
    r : 2-D array-like
        Room class indices.
    name : str
        Output path without extension; the image is saved to ``name + ".png"``.
    n_classes : int, default 12
        Number of classes. Fixes the colour scale to ``[0, n_classes - 1]``.
    """
    discrete_cmap()  # custom colormap
    fig = plt.figure(figsize=(40, 30))
    plt.axis("off")
    plt.tight_layout()
    plt.imshow(r, cmap="rooms", vmin=0, vmax=n_classes - 1)
    plt.savefig(name + ".png", format="png")
    plt.show()
    # With a non-interactive backend (e.g. Agg on a headless server), show()
    # leaves the figure open, so each call would leak a 40x30-inch figure.
    # In interactive mode, keep it so the window stays visible.
    if not plt.isinteractive():
        plt.close(fig)
def plot_icon(i, name, n_classes=11):
    """Render an icon segmentation with the ``"icons"`` colormap, save it, then show it.

    Parameters
    ----------
    i : 2-D array-like
        Icon class indices.
    name : str
        Output path without extension; the image is saved to ``name + ".png"``.
    n_classes : int, default 11
        Number of classes. Fixes the colour scale to ``[0, n_classes - 1]``.
    """
    discrete_cmap()  # custom colormap
    fig = plt.figure(figsize=(40, 30))
    plt.axis("off")
    plt.tight_layout()
    plt.imshow(i, cmap="icons", vmin=0, vmax=n_classes - 1)
    plt.savefig(name + ".png", format="png")
    plt.show()
    # With a non-interactive backend (e.g. Agg on a headless server), show()
    # leaves the figure open, so each call would leak a 40x30-inch figure.
    # In interactive mode, keep it so the window stays visible.
    if not plt.isinteractive():
        plt.close(fig)
def plot_heatmaps(h, name):
    """Save and show each heatmap in ``h`` as its own figure.

    Heatmap ``index`` is drawn with the ``"Reds"`` colormap on a fixed
    ``[0, 1]`` scale and saved to ``name + str(index) + ".png"``.

    Parameters
    ----------
    h : iterable of 2-D array-like
        The heatmaps; values are presumably in [0, 1] (the colour scale is
        clipped there).
    name : str
        Output path prefix.
    """
    for index, heatmap in enumerate(h):
        fig = plt.figure(figsize=(40, 30))
        plt.axis("off")
        plt.tight_layout()
        plt.imshow(heatmap, cmap="Reds", vmin=0, vmax=1)
        plt.savefig(name + str(index) + ".png", format="png")
        plt.show()
        # With a non-interactive backend, show() leaves the figure open, so
        # this loop would keep one 40x30-inch figure alive per heatmap.
        # In interactive mode, keep it so the window stays visible.
        if not plt.isinteractive():
            plt.close(fig)
def outline_to_mask(line, x, y):
    """Build a boolean mask of the grid points enclosed by an outline.

    Parameters
    ----------
    line : array-like (N, 2)
        Outline vertices, passed to ``matplotlib.path.Path``.
    x, y : 1-D grid coordinates
        The inputs to ``numpy.meshgrid``.

    Returns
    -------
    mask : 2-D boolean array of shape ``(len(y), len(x))``
        True where the grid point lies inside the outline.

    Examples
    --------
    >>> from shapely.geometry import Point
    >>> poly = Point(0,0).buffer(1)
    >>> x = np.linspace(-5,5,100)
    >>> y = np.linspace(-5,5,100)
    >>> mask = outline_to_mask(np.asarray(poly.exterior.coords), x, y)
    """
    outline = mplp.Path(line)
    grid_x, grid_y = np.meshgrid(x, y)
    # One (x, y) row per grid point, in row-major grid order.
    query = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    inside = outline.contains_points(query)
    return inside.reshape(grid_x.shape)
| def _grid_bbox(x, y): | |
| dx = dy = 0 | |
| return x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2 | |
def _bbox_to_rect(bbox):
    """Build an axis-aligned shapely ``Polygon`` from ``(left, right, bottom, top)``.

    The corners are emitted in the order (left, bottom), (right, bottom),
    (right, top), (left, top).
    """
    left, right, bottom, top = bbox
    corners = [(left, bottom), (right, bottom), (right, top), (left, top)]
    return Polygon(corners)
def shp_mask(shp, x, y, m=None):
    """Rasterise a shape onto a regular grid by recursive subdivision.

    Adapted from code written by perrette
    from: https://gist.github.com/perrette/a78f99b76aed54b6babf3597e0b331f8

    The rectangle spanned by the current grid coordinates is tested against
    ``shp`` as a whole:

    * If the shape does not intersect it, the whole block is False.
    * If the shape contains it, the whole block is True.
    * Otherwise the block is halved along every axis longer than one cell,
      and each part is processed recursively.

    A single remaining cell is decided by ``shp.contains`` on its coordinate
    point.

    Parameters
    ----------
    shp : shapely Polygon (or anything with ``contains`` and ``intersects`` methods)
    x, y : 1-D numpy arrays defining a regular grid
    m : mask to fill, optional (created when omitted)

    Returns
    -------
    m : boolean 2-D array of shape ``(y.size, x.size)``, True inside the shape.

    Examples
    --------
    >>> from shapely.geometry import Point
    >>> poly = Point(0,0).buffer(1)
    >>> x = np.linspace(-5,5,100)
    >>> y = np.linspace(-5,5,100)
    >>> mask = shp_mask(poly, x, y)
    """
    block = _bbox_to_rect(_grid_bbox(x, y))
    if m is None:
        m = np.zeros((y.size, x.size), dtype=bool)

    if not shp.intersects(block):
        m[:] = False
        return m
    if shp.contains(block):
        m[:] = True
        return m

    n_rows, n_cols = m.shape
    if n_rows == 1 and n_cols == 1:
        m[:] = shp.contains(Point(x[0], y[0]))
        return m

    # Halve each axis that is longer than one cell; a length-one axis stays whole.
    if n_rows == 1:
        row_parts = [slice(None)]
    else:
        row_parts = [slice(None, n_rows // 2), slice(n_rows // 2, None)]
    if n_cols == 1:
        col_parts = [slice(None)]
    else:
        col_parts = [slice(None, n_cols // 2), slice(n_cols // 2, None)]

    # Row-major order: top-left, top-right, bottom-left, bottom-right.
    for rows in row_parts:
        for cols in col_parts:
            m[rows, cols] = shp_mask(shp, x[cols], y[rows], m[rows, cols])
    return m