Spaces:
Runtime error
Runtime error
anas commited on
Commit ·
fadb92b
1
Parent(s): 54d83cd
Initial deployment of Raster2Seq floor plan vectorization API
Browse files- Full inference pipeline with Gradio UI and API endpoint
- Docker-based deployment with CUDA 11.8 for GPU inference
- Semantic room detection with polygon output (13 room types)
- Checkpoint auto-downloaded during Docker build
Made-with: Cursor
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +15 -0
- Dockerfile +27 -0
- LICENSE +21 -0
- README.md +24 -5
- app.py +303 -0
- data_preprocess/README.md +40 -0
- data_preprocess/common_utils.py +45 -0
- data_preprocess/cubicasa5k/augmentations.py +703 -0
- data_preprocess/cubicasa5k/combine_json.py +118 -0
- data_preprocess/cubicasa5k/create_coco_cc5k.py +672 -0
- data_preprocess/cubicasa5k/floorplan_extraction.py +403 -0
- data_preprocess/cubicasa5k/house.py +1131 -0
- data_preprocess/cubicasa5k/loaders.py +158 -0
- data_preprocess/cubicasa5k/plotting.py +820 -0
- data_preprocess/cubicasa5k/run.sh +15 -0
- data_preprocess/cubicasa5k/svg_utils.py +746 -0
- data_preprocess/raster2graph/combine_json.py +122 -0
- data_preprocess/raster2graph/combine_mapping_ids.py +95 -0
- data_preprocess/raster2graph/convert_to_coco.py +472 -0
- data_preprocess/raster2graph/dataset.py +296 -0
- data_preprocess/raster2graph/image_process.py +67 -0
- data_preprocess/raster2graph/util/data_utils.py +966 -0
- data_preprocess/raster2graph/util/edges_utils.py +46 -0
- data_preprocess/raster2graph/util/geom_utils.py +124 -0
- data_preprocess/raster2graph/util/graph_utils.py +879 -0
- data_preprocess/raster2graph/util/image_id_dict.py +0 -0
- data_preprocess/raster2graph/util/math_utils.py +7 -0
- data_preprocess/raster2graph/util/mean_std.py +2 -0
- data_preprocess/raster2graph/util/metric_utils.py +338 -0
- data_preprocess/raster2graph/util/semantics_dict.py +45 -0
- data_preprocess/stru3d/PointCloudReaderPanorama.py +253 -0
- data_preprocess/stru3d/generate_coco_stru3d.py +199 -0
- data_preprocess/stru3d/generate_point_cloud_stru3d.py +32 -0
- data_preprocess/stru3d/stru3d_utils.py +244 -0
- data_preprocess/tools/plot_data.sh +60 -0
- data_preprocess/tools/run_cc5k.sh +15 -0
- data_preprocess/tools/run_r2g.sh +12 -0
- data_preprocess/tools/run_s3d.sh +22 -0
- data_preprocess/tools/run_waffle.sh +3 -0
- data_preprocess/waffle/create_coco_waffle_benchmark.py +290 -0
- datasets/__init__.py +67 -0
- datasets/data_utils.py +60 -0
- datasets/discrete_tokenizer.py +60 -0
- datasets/poly_data.py +590 -0
- datasets/room_dropout.py +237 -0
- datasets/transforms.py +46 -0
- detectron2/__init__.py +10 -0
- detectron2/checkpoint/__init__.py +11 -0
- detectron2/checkpoint/c2_model_loading.py +387 -0
- detectron2/checkpoint/catalog.py +113 -0
.gitignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.egg-info/
|
| 5 |
+
build/
|
| 6 |
+
dist/
|
| 7 |
+
data/
|
| 8 |
+
data2/
|
| 9 |
+
output*/
|
| 10 |
+
wandb*/
|
| 11 |
+
checkpoints/
|
| 12 |
+
slurm_scripts*
|
| 13 |
+
watch_folder
|
| 14 |
+
cross_eval_out
|
| 15 |
+
*.log
|
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CUDA 11.8 "devel" base image on Ubuntu 22.04.
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

# System packages: Python 3.10 + pip, git/wget, and the libGL / GLib shared
# libraries (presumably needed by OpenCV at import time - confirm).
# The apt lists are removed in the same layer to keep the image smaller.
RUN apt-get update && apt-get install -y \
    python3.10 python3-pip git wget libgl1-mesa-glx libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Non-root user with uid 1000; the container switches to it at the end.
RUN useradd -m -u 1000 user
WORKDIR /app
# NOTE(review): copying the whole repository before the pip steps means any
# source change invalidates the cached dependency layers below.
COPY . /app

# PyTorch wheels built against CUDA 11.8, then the project requirements,
# then the extra packages used by the demo (Gradio UI, gdown for the checkpoint).
RUN pip3 install torch==2.3.1 torchvision==0.18.1 \
    --index-url https://download.pytorch.org/whl/cu118
RUN pip3 install -r requirements.txt
RUN pip3 install gradio gdown

# Build the project's native extensions in place.
# NOTE(review): no GPU is visible during `docker build`; confirm these build
# scripts do not require torch.cuda.is_available() to be true.
RUN cd models/ops && sh make.sh && cd ../..
RUN cd diff_ras && python3 setup.py build develop && cd ..

# Download the model checkpoint at build time to the path app.py loads from.
RUN mkdir -p checkpoints && \
    gdown --fuzzy "https://drive.google.com/file/d/1M32HlYwXw-4Q_uajSCvpbF31UFPzQVHP/view?usp=sharing" \
    -O checkpoints/r2g_res256_ep0849.pth

# Hand the application tree to the non-root user and drop privileges.
RUN chown -R user:user /app
USER user

# Port the Gradio server in app.py listens on.
EXPOSE 7860
CMD ["python3", "app.py"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Raster2Seq
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,10 +1,29 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Raster2Seq
|
| 3 |
+
emoji: 🏠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 7860
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# Raster2Seq - Floor Plan Vectorization
|
| 12 |
+
|
| 13 |
+
Upload a floor plan image to detect room polygons and their semantic labels.
|
| 14 |
+
|
| 15 |
+
This Space runs the Raster2Seq model for converting raster floor plan images into vectorized polygon sequences with room type classification.
|
| 16 |
+
|
| 17 |
+
## API Usage
|
| 18 |
+
|
| 19 |
+
This Space exposes a Gradio API. You can call it programmatically:
|
| 20 |
+
|
| 21 |
+
```python
|
| 22 |
+
from gradio_client import Client
|
| 23 |
+
|
| 24 |
+
client = Client("AGLO-AI/raster2seq")
|
| 25 |
+
result = client.predict(
|
| 26 |
+
image="path/to/floorplan.png",
|
| 27 |
+
api_name="/predict"
|
| 28 |
+
)
|
| 29 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from shapely.geometry import Polygon
|
| 12 |
+
|
| 13 |
+
from datasets.discrete_tokenizer import DiscreteTokenizer
|
| 14 |
+
from datasets.transforms import ResizeAndPad
|
| 15 |
+
from detectron2.data import transforms as T
|
| 16 |
+
from models import build_model
|
| 17 |
+
from util.plot_utils import plot_semantic_rich_floorplan_opencv
|
| 18 |
+
|
| 19 |
+
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model configuration consumed by load_model(): num_bins / seq_len /
# add_cls_token build the tokenizer, and the whole namespace is passed to
# build_model().
# NOTE(review): these values presumably have to match the configuration the
# checkpoint was trained with - confirm before changing any of them.
MODEL_ARGS = argparse.Namespace(
    poly2seq=True,
    seq_len=512,
    num_bins=32,
    image_size=256,  # side length inputs are resized/padded to (see DATA_TRANSFORM)
    input_channels=3,
    backbone="resnet50",
    dilation=False,
    position_embedding="sine",
    position_embedding_scale=2 * np.pi,
    num_feature_levels=4,
    enc_layers=6,
    dec_layers=6,
    dim_feedforward=1024,
    hidden_dim=256,
    dropout=0.1,
    nheads=8,
    num_queries=800,
    num_polys=20,
    dec_n_points=4,
    enc_n_points=4,
    query_pos_type="sine",
    with_poly_refine=False,
    masked_attn=False,
    semantic_classes=13,  # > 0 switches predict_floorplan() to semantic output
    aux_loss=False,
    dec_attn_concat_src=True,
    pre_decoder_pos_embed=False,
    learnable_dec_pe=False,
    dec_qkv_proj=False,
    per_token_sem_loss=True,
    add_cls_token=False,
    use_anchor=True,
    inject_cls_embed=False,
    # Same CUDA/CPU choice as DEVICE, stored as a plain string.
    device="cuda" if torch.cuda.is_available() else "cpu",
)
|
| 57 |
+
|
| 58 |
+
# Maps predicted class ids to display names. Used for the "label" field of the
# JSON result and passed to the plotting helper as semantics_label_mapping.
# NOTE(review): MODEL_ARGS.semantic_classes is 13 while this table lists 18
# ids (0-17) - confirm the names match the label set of the loaded checkpoint.
R2G_LABEL = {
    0: "Living Room",
    1: "Kitchen",
    2: "Bedroom",
    3: "Bathroom",
    4: "Balcony",
    5: "Corridor",
    6: "Dining Room",
    7: "Study",
    8: "Studio",
    9: "Store Room",
    10: "Garden",
    11: "Laundry Room",
    12: "Office",
    13: "Basement",
    14: "Garage",
    15: "Undefined",
    16: "Door",
    17: "Window",
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _process_predictions(
|
| 81 |
+
pred_corners, i, semantic_rich, image_size, pred_room_label,
|
| 82 |
+
pred_room_logits, per_token_sem_loss, add_cls_token=False,
|
| 83 |
+
):
|
| 84 |
+
"""Extract polygons from poly2seq model output."""
|
| 85 |
+
np_softmax = lambda x: np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)
|
| 86 |
+
pred_corners_per_scene = pred_corners[i]
|
| 87 |
+
room_polys = []
|
| 88 |
+
|
| 89 |
+
if semantic_rich:
|
| 90 |
+
room_types = []
|
| 91 |
+
window_doors = []
|
| 92 |
+
window_doors_types = []
|
| 93 |
+
pred_room_label_per_scene = pred_room_label[i].cpu().numpy()
|
| 94 |
+
pred_room_logit_per_scene = pred_room_logits[i].cpu().numpy()
|
| 95 |
+
|
| 96 |
+
all_room_polys = []
|
| 97 |
+
tmp = []
|
| 98 |
+
all_length_list = [0]
|
| 99 |
+
|
| 100 |
+
for j in range(len(pred_corners_per_scene)):
|
| 101 |
+
if isinstance(pred_corners_per_scene[j], int):
|
| 102 |
+
if pred_corners_per_scene[j] == 2 and tmp:
|
| 103 |
+
all_room_polys.append(tmp)
|
| 104 |
+
all_length_list.append(len(tmp) + 1 + add_cls_token)
|
| 105 |
+
tmp = []
|
| 106 |
+
continue
|
| 107 |
+
tmp.append(pred_corners_per_scene[j])
|
| 108 |
+
|
| 109 |
+
if len(tmp):
|
| 110 |
+
all_room_polys.append(tmp)
|
| 111 |
+
all_length_list.append(len(tmp) + 1 + add_cls_token)
|
| 112 |
+
|
| 113 |
+
start_poly_indices = np.cumsum(all_length_list)
|
| 114 |
+
final_pred_classes = []
|
| 115 |
+
|
| 116 |
+
for j, poly in enumerate(all_room_polys):
|
| 117 |
+
if len(poly) < 2:
|
| 118 |
+
continue
|
| 119 |
+
corners = np.array(poly, dtype=np.float32) * (image_size - 1)
|
| 120 |
+
corners = np.around(corners).astype(np.int32)
|
| 121 |
+
|
| 122 |
+
if not semantic_rich:
|
| 123 |
+
if len(corners) >= 4 and Polygon(corners).area >= 100:
|
| 124 |
+
room_polys.append(corners)
|
| 125 |
+
else:
|
| 126 |
+
if per_token_sem_loss:
|
| 127 |
+
pred_classes, counts = np.unique(
|
| 128 |
+
pred_room_label_per_scene[start_poly_indices[j]:start_poly_indices[j + 1]][:-1],
|
| 129 |
+
return_counts=True,
|
| 130 |
+
)
|
| 131 |
+
pred_class = pred_classes[np.argmax(counts)]
|
| 132 |
+
else:
|
| 133 |
+
pred_class = pred_room_label_per_scene[start_poly_indices[j + 1] - 1]
|
| 134 |
+
final_pred_classes.append(pred_class)
|
| 135 |
+
|
| 136 |
+
if len(corners) >= 3 and Polygon(corners).area >= 100:
|
| 137 |
+
room_polys.append(corners)
|
| 138 |
+
room_types.append(pred_class)
|
| 139 |
+
elif len(corners) == 2:
|
| 140 |
+
window_doors.append(corners)
|
| 141 |
+
window_doors_types.append(pred_class)
|
| 142 |
+
|
| 143 |
+
if not semantic_rich:
|
| 144 |
+
pred_room_label_per_scene = len(all_room_polys) * [-1]
|
| 145 |
+
|
| 146 |
+
return {
|
| 147 |
+
"room_polys": room_polys,
|
| 148 |
+
"room_types": room_types if semantic_rich else None,
|
| 149 |
+
"window_doors": window_doors if semantic_rich else None,
|
| 150 |
+
"window_doors_types": window_doors_types if semantic_rich else None,
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@torch.no_grad()
def generate(model, samples, semantic_rich=False, use_cache=True, per_token_sem_loss=False):
    """Run poly2seq inference and decode the polygons of every scene in the batch.

    Returns a dict with ``"room"`` (one list of polygons per scene, windows
    and doors appended after the rooms) and ``"labels"`` (the matching class
    lists, or ``None`` per scene when no semantic output is produced).
    """
    model.eval()
    image_size = samples[0].size(2)

    outputs = model.forward_inference(samples, use_cache)
    pred_corners = outputs["gen_out"]
    batch_size = outputs["pred_logits"].shape[0]

    # Per-token class ids: most probable class, with the last class excluded.
    pred_room_label = None
    pred_room_logits = None
    if "pred_room_logits" in outputs:
        pred_room_logits = outputs["pred_room_logits"]
        probabilities = torch.nn.functional.softmax(pred_room_logits, -1)
        pred_room_label = probabilities[..., :-1].max(-1)[1]

    result_rooms = []
    result_classes = []
    for scene_idx in range(batch_size):
        scene = _process_predictions(
            pred_corners, scene_idx, semantic_rich, image_size,
            pred_room_label, pred_room_logits, per_token_sem_loss,
        )
        polygons = scene["room_polys"]
        classes = scene["room_types"]
        openings = scene["window_doors"]

        # Windows/doors (if any) go after the rooms, labels in the same order.
        if openings:
            polygons = polygons + openings
            classes = classes + scene["window_doors_types"]

        result_rooms.append(polygons)
        result_classes.append(classes)

    return {"room": result_rooms, "labels": result_classes}
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def load_model():
    """Build the model, load the EMA checkpoint weights and freeze it for inference.

    Side effect: sets ``MODEL_ARGS.vocab_size`` from the tokenizer before the
    model is built.

    Returns:
        The model on ``DEVICE``, in eval mode, with gradients disabled.
    """
    tokenizer = DiscreteTokenizer(
        MODEL_ARGS.num_bins, MODEL_ARGS.seq_len, add_cls=MODEL_ARGS.add_cls_token
    )
    MODEL_ARGS.vocab_size = len(tokenizer)

    model = build_model(MODEL_ARGS, train=False, tokenizer=tokenizer)
    model.to(DEVICE)

    ckpt_path = "checkpoints/r2g_res256_ep0849.pth"
    # NOTE(security): torch.load unpickles the file, so the checkpoint must
    # come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    # Strip the "module." prefix from parameter names. A new dict is built
    # instead of deep-copying the state dict, which would duplicate every
    # weight tensor in memory.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["ema"].items()
    }

    # strict=False tolerates mismatched keys; report them so a checkpoint
    # that does not fit the model is visible in the logs.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: keys missing from checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")

    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# The model is built once at import time and reused by every request.
print("Loading model...")
MODEL = load_model()
print("Model loaded.")

# Input transform: resize/pad to a square of MODEL_ARGS.image_size pixels,
# using 255 as the padding value.
# NOTE(review): exact resize/pad semantics live in ResizeAndPad - confirm there.
DATA_TRANSFORM = T.AugmentationList(
    [ResizeAndPad((MODEL_ARGS.image_size, MODEL_ARGS.image_size), pad_value=255)]
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def preprocess_image(pil_image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a channel-first float32 tensor scaled by 1/255.

    The image is converted to RGB, passed through ``DATA_TRANSFORM`` and
    rearranged from (H, W, C) to (C, H, W); a 2-D result gets a leading
    channel axis instead.
    """
    aug_input = T.AugInput(np.array(pil_image.convert("RGB")))
    DATA_TRANSFORM(aug_input)  # result is read back from aug_input.image
    transformed = aug_input.image

    if len(transformed.shape) == 2:
        channel_first = np.expand_dims(transformed, 0)
    else:
        channel_first = transformed.transpose((2, 0, 1))

    as_float = torch.as_tensor(channel_first, dtype=torch.float32)
    return (1 / 255) * as_float
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def predict_floorplan(image: Image.Image):
    """Gradio handler: vectorize one floor plan image.

    Returns a ``(visualization, polygons)`` pair: a PIL image of the rendered
    result (or ``None``) and a list of ``{"label", "label_id", "polygon"}``
    dicts. Without an input image the second element is a JSON error string.
    """
    if image is None:
        return None, json.dumps({"error": "No image provided"})

    batch = preprocess_image(image).unsqueeze(0).to(DEVICE)
    outputs = generate(
        MODEL,
        batch,
        semantic_rich=MODEL_ARGS.semantic_classes > 0,
        use_cache=True,
        per_token_sem_loss=MODEL_ARGS.per_token_sem_loss,
    )

    # Single-image batch: take the first (only) scene.
    pred_rooms = outputs["room"][0]
    pred_labels = outputs["labels"][0]
    if pred_labels is None:
        pred_labels = [-1] * len(pred_rooms)

    result_polygons = []
    for poly, label in zip(pred_rooms, pred_labels):
        label_id = int(label)
        entry = {
            "label": R2G_LABEL.get(label_id, "Unknown"),
            "label_id": label_id,
            "polygon": poly.astype(float).tolist(),
        }
        result_polygons.append(entry)

    # Render at twice the model resolution.
    canvas_side = MODEL_ARGS.image_size * 2
    floorplan_map = plot_semantic_rich_floorplan_opencv(
        zip(pred_rooms, pred_labels),
        None,
        door_window_index=[],
        semantics_label_mapping=R2G_LABEL,
        plot_text=True,
        one_color=False,
        is_sem=True,
        img_w=canvas_side,
        img_h=canvas_side,
        scale=2,
    )

    if floorplan_map is None or floorplan_map.size <= 0:
        return None, result_polygons

    # OpenCV output is BGR; PIL expects RGB.
    floorplan_rgb = cv2.cvtColor(floorplan_map, cv2.COLOR_BGR2RGB)
    return Image.fromarray(floorplan_rgb), result_polygons
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# Gradio UI/API: one image in, a rendered visualization plus the polygon JSON out.
demo = gr.Interface(
    fn=predict_floorplan,
    inputs=gr.Image(type="pil", label="Floor Plan Image"),
    outputs=[
        gr.Image(type="pil", label="Detected Rooms"),
        gr.JSON(label="Detected Polygons"),
    ],
    title="Raster2Seq - Floor Plan Vectorization",
    description="Upload a floor plan image to detect room polygons and their semantic labels. Returns both a visualization and structured JSON with polygon coordinates.",
)

# Listen on all interfaces on port 7860 (the port exposed by the Dockerfile).
# NOTE: this runs at import time; the module is meant to be executed as a script.
demo.launch(server_name="0.0.0.0", server_port=7860)
|
data_preprocess/README.md
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Data preprocessing
|
| 2 |
+
|
| 3 |
+
### Structured3D
|
| 4 |
+
|
| 5 |
+
Simply download the data preprocessed by RoomFormer [here](https://polybox.ethz.ch/index.php/s/wKYWFsQOXHnkwcG). For more details, please refer to [RoomFormer's instructions](https://github.com/ywyue/RoomFormer/tree/main/data_preprocess).
|
| 6 |
+
|
| 7 |
+
To render binary floorplan images from GT annotations (as used in our paper), run `bash data_preprocess/tools/run_s3d.sh`.
|
| 8 |
+
|
| 9 |
+
### CubiCasa5K
|
| 10 |
+
Step 1: Download and extract [CubiCasa5K](https://zenodo.org/record/2613548) dataset.
|
| 11 |
+
|
| 12 |
+
Step 2: Run `bash data_preprocess/cubicasa5k/run.sh`.
|
| 13 |
+
|
| 14 |
+
### Raster2Graph
|
| 15 |
+
These instructions mainly follow Raster2Graph's.
|
| 16 |
+
|
| 17 |
+
Step 1: Due to dataset proprietary restrictions, please apply for access to LIFULL HOME'S Data [here](https://www.nii.ac.jp/dsc/idr/en/lifull/).
|
| 18 |
+
|
| 19 |
+
Step 2: After obtaining access, download only the "photo-rent-madori-full-00" folder, which contains approximately 300,000 images.
|
| 20 |
+
|
| 21 |
+
Step 3: Apply for access to the annotation [here](https://docs.google.com/forms/d/e/1FAIpQLSexqNMjyvPMtPMPN7bSh_1u4Q27LZAT-S9lR_gpipNIMKV5lw/viewform).
|
| 22 |
+
|
| 23 |
+
The package has 3 folders:
|
| 24 |
+
- annot_npy, annot_json: the annotations saved in npy and json, respectively.
|
| 25 |
+
- original_vector_boundary: boundary boxes of "LIFULL HOME'S Data" which is used to create centered 512x512 images.
|
| 26 |
+
|
| 27 |
+
These folders should be saved in the same directory as `photo-rent-madori-full-00`. For example: `data/R2G_hr_dataset/`.
|
| 28 |
+
|
| 29 |
+
Step 4: Run `bash data_preprocess/tools/run_r2g.sh`.
|
| 30 |
+
|
| 31 |
+
### WAFFLE
|
| 32 |
+
|
| 33 |
+
Note that WAFFLE only provides segmentation masks for a subset of 100 examples, so we process only this subset and use it for evaluation only, not for training.
|
| 34 |
+
|
| 35 |
+
Step 1: Download data at [here](https://tauex-my.sharepoint.com/:f:/g/personal/hadarelor_tauex_tau_ac_il/EqMX9nRbJ9xFiK7dR_m07b8BldS2saoZ4-ockqncJb_Hrg?e=zGIuos)
|
| 36 |
+
|
| 37 |
+
Step 2: Run `bash data_preprocess/tools/run_waffle.sh`.
|
| 38 |
+
|
| 39 |
+
## Data visualization
|
| 40 |
+
Please refer to this script [tools/plot_data.sh](tools/plot_data.sh).
|
data_preprocess/common_utils.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
from plyfile import PlyData
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def read_scene_pc(file_path):
    """Read a PLY file and return its vertex records as a NumPy array.

    Prints the dtype of the vertex data as a side effect.
    """
    with open(file_path, "rb") as ply_file:
        plydata = PlyData.read(ply_file)

    vertices = plydata["vertex"].data
    print("dtype of file{}: {}".format(file_path, vertices.dtype))

    # One row per vertex, one column per vertex property.
    return np.array(vertices.tolist())
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def is_clockwise(points):
    """Return True when the signed edge sum of the closed polygon is positive.

    Args:
        points: non-empty list of 2D points; the polygon is closed
            implicitly from the last point back to the first.

    The sum of ``(x2 - x1) * (y2 + y1)`` over all edges is positive for one
    winding direction and negative for the other. Which one looks clockwise
    on screen depends on whether the y axis points up or down.
    """
    # points is a list of 2d points.
    assert len(points) > 0
    s = 0.0
    for p1, p2 in zip(points, points[1:] + [points[0]]):
        s += (p2[0] - p1[0]) * (p2[1] + p1[1])
    return s > 0.0


def resort_corners(corners):
    """Return the corners re-ordered to a canonical start point and winding.

    The result starts at the corner closest to the origin and follows the
    winding for which ``is_clockwise`` is true. The input array is left
    unmodified.

    Args:
        corners: array of shape (N, >=2); columns 0 and 1 are x and y.

    Returns:
        A new array holding the same rows in canonical order.
    """
    # Re-find the starting point: the corner with the smallest squared
    # distance to the origin.
    x_y_square_sum = corners[:, 0] ** 2 + corners[:, 1] ** 2
    start_corner_idx = np.argmin(x_y_square_sum)

    corners_sorted = np.concatenate([corners[start_corner_idx:], corners[:start_corner_idx]])

    # Sort points clockwise: keep the start corner and reverse the rest.
    # The flipped view is copied first so the slice is not assigned onto
    # its own memory.
    if not is_clockwise(corners_sorted[:, :2].tolist()):
        corners_sorted[1:] = np.flip(corners_sorted[1:], 0).copy()

    # Fix: the original returned the untouched input ``corners``, discarding
    # the re-ordering computed above.
    return corners_sorted
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def export_density(density_map, out_folder, scene_id):
    """Write ``density_map`` to ``<out_folder>/<scene_id>.png`` as an 8-bit image.

    The map is multiplied by 255 and cast to uint8 before writing.
    """
    as_uint8 = (density_map * 255).astype(np.uint8)
    target = os.path.join(out_folder, scene_id + ".png")
    cv2.imwrite(target, as_uint8)
|
data_preprocess/cubicasa5k/augmentations.py
ADDED
|
@@ -0,0 +1,703 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from math import inf
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from floortrans.loaders import svg_utils
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Compose(object):
    """Apply a sequence of augmentations to a sample, one after another."""

    def __init__(self, augmentations):
        # Callables applied in order; each receives the previous one's output.
        self.augmentations = augmentations

    def __call__(self, sample):
        result = sample
        for augment in self.augmentations:
            result = augment(result)
        return result
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# 0. I
|
| 22 |
+
# 1. I top to right
|
| 23 |
+
# 2. I vertical flip
|
| 24 |
+
# 3. I top to left
|
| 25 |
+
# 4. L horizontal flip
|
| 26 |
+
# 5. L
|
| 27 |
+
# 6. L vertical flip
|
| 28 |
+
# 7. L horizontal and vertical flip
|
| 29 |
+
# 8. T
|
| 30 |
+
# 9. T top to right
|
| 31 |
+
# 10. T top to down
|
| 32 |
+
# 11. T top to left
|
| 33 |
+
# 12. X or +
|
| 34 |
+
# 13. Opening left corner
|
| 35 |
+
# 14. Opening right corner
|
| 36 |
+
# 15. Opening up corner
|
| 37 |
+
# 16. Opening down corner
|
| 38 |
+
# 17. Icon upper left
|
| 39 |
+
# 18. Icon upper right
|
| 40 |
+
# 19. Icon lower left
|
| 41 |
+
# 20. Icon lower right
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class RandomRotations(object):
    """Rotate a sample by a random number of quarter turns.

    Each turn transposes the last two axes of the image and label tensors and
    flips the last axis. Heatmap points are moved accordingly and their
    junction type is remapped to the type it becomes after the turn.

    Args:
        format: ``"furu"`` (points under the ``"heatmap_points"`` key) or
            ``"cubi"`` (points under ``"heatmaps"``, plus a ``"scale"`` entry
            that is passed through).

    Raises:
        ValueError: if ``format`` is not one of the supported values.
    """

    # Junction type index -> index of the same junction after one quarter turn.
    _ROTATED_JUNCTION_TYPE = {
        0: 1,
        1: 2,
        2: 3,
        3: 0,
        4: 5,
        5: 6,
        6: 7,
        7: 4,
        8: 9,
        9: 10,
        10: 11,
        11: 8,
        12: 12,
        13: 15,
        14: 16,
        15: 14,
        16: 13,
        17: 18,
        18: 20,
        19: 17,
        20: 19,
    }

    def __init__(self, format="furu"):
        if format == "furu":
            self.augment = self.furu
        elif format == "cubi":
            self.augment = self.cubi
        else:
            # Fail at construction instead of with an AttributeError on the
            # first call.
            raise ValueError(
                "Unknown format {!r}; expected 'furu' or 'cubi'".format(format)
            )

    def __call__(self, sample):
        return self.augment(sample)

    def _rotate_once(self, fplan, segmentation, heatmap_points):
        """Apply one quarter turn to the tensors and the heatmap points."""
        fplan = fplan.transpose(2, 1).flip(2)
        segmentation = segmentation.transpose(2, 1).flip(2)

        # The size is read from the already-rotated image.
        side = fplan.shape[1]
        points_rotated = dict()
        for junction_type, points in heatmap_points.items():
            new_junction_type = self._ROTATED_JUNCTION_TYPE[junction_type]
            points_rotated[new_junction_type] = [
                [side - 1 - point[1], point[0]] for point in points
            ]

        return fplan, segmentation, points_rotated

    def _rotate_randomly(self, fplan, segmentation, heatmap_points):
        """Apply 0, 1 or 2 quarter turns (torch.randint's upper bound is exclusive)."""
        # NOTE(review): three quarter turns are never sampled - confirm intent.
        num_of_rotations = int(torch.randint(0, 3, (1,)))
        for _ in range(num_of_rotations):
            fplan, segmentation, heatmap_points = self._rotate_once(
                fplan, segmentation, heatmap_points
            )
        return fplan, segmentation, heatmap_points

    def cubi(self, sample):
        """Rotate a sample with ``image``, ``label``, ``heatmaps`` and ``scale`` keys."""
        scale = sample["scale"]
        fplan, segmentation, heatmap_points = self._rotate_randomly(
            sample["image"], sample["label"], sample["heatmaps"]
        )

        sample = {"image": fplan, "label": segmentation, "scale": scale, "heatmaps": heatmap_points}

        return sample

    def furu(self, sample):
        """Rotate a sample with ``image``, ``label`` and ``heatmap_points`` keys."""
        fplan, segmentation, heatmap_points = self._rotate_randomly(
            sample["image"], sample["label"], sample["heatmap_points"]
        )

        sample = {"image": fplan, "label": segmentation, "heatmap_points": heatmap_points}

        return sample
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def clip_heatmaps(heatmaps, minx, maxx, miny, maxy):
    """Keep the points inside a window and re-origin them to its corner.

    Args:
        heatmaps: dict mapping a key to an iterable of points; element 0 of a
            point is compared against the x bounds, element 1 against the y
            bounds.
        minx, maxx: half-open x range ``[minx, maxx)`` to keep.
        miny, maxy: half-open y range ``[miny, maxy)`` to keep.

    Returns:
        A new dict with the same keys; each value is a list of
        ``(x - minx, y - miny)`` tuples for the surviving points.
    """
    clipped = {}
    for channel, points in heatmaps.items():
        clipped[channel] = [
            (point[0] - minx, point[1] - miny)
            for point in points
            if minx <= point[0] < maxx and miny <= point[1] < maxy
        ]

    return clipped
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class DictToTensor(object):
    """Turn junction point lists into smoothed heatmap channels.

    The points of each channel are rasterised into a 21-channel array with
    the spatial size of ``sample["label"]``, every channel is filtered with
    the kernel returned by ``svg_utils.get_gaussian2D`` and the result is
    concatenated in front of the label along dim 0.

    ``data_format`` selects the sample layout: ``"cubi"`` reads the points
    from ``"heatmaps"`` and sizes the kernel from ``"scale"``; ``"furukawa"``
    reads them from ``"heatmap_points"`` and uses a fixed kernel argument.
    """

    def __init__(self, data_format="cubi"):
        # NOTE: other values leave ``self.augment`` unset.
        if data_format == "cubi":
            self.augment = self.cubi
        elif data_format == "furukawa":
            self.augment = self.furukawa

    def __call__(self, sample):
        return self.augment(sample)

    def cubi(self, sample):
        """Build heatmaps for a ``"heatmaps"``/``"scale"`` style sample."""
        image = sample["image"]
        label = sample["label"]
        _, height, width = label.shape
        junctions = sample["heatmaps"]
        scale = sample["scale"]

        heatmap_tensor = np.zeros((21, height, width))
        for channel, coords in junctions.items():
            for col, row in coords:
                # A coordinate at or past the border is pulled back by one.
                if col >= width:
                    col -= 1
                if row >= height:
                    row -= 1
                heatmap_tensor[int(channel), int(row), int(col)] = 1

        # Smooth every channel; the kernel argument grows with the scale.
        kernel = svg_utils.get_gaussian2D(int(30 * scale))
        for index in range(len(heatmap_tensor)):
            heatmap_tensor[index] = cv2.filter2D(heatmap_tensor[index], -1, kernel)

        stacked = torch.cat((torch.FloatTensor(heatmap_tensor), label), 0)

        return {"image": image, "label": stacked}

    def furukawa(self, sample):
        """Build heatmaps for a ``"heatmap_points"`` style sample."""
        image = sample["image"]
        label = sample["label"]
        _, height, width = label.shape
        junctions = sample["heatmap_points"]

        heatmap_tensor = np.zeros((21, height, width))
        for channel, coords in junctions.items():
            for col, row in coords:
                heatmap_tensor[int(channel), int(row), int(col)] = 1

        # Smooth every channel with a fixed kernel and a constant border.
        kernel = svg_utils.get_gaussian2D(13)
        for index in range(len(heatmap_tensor)):
            heatmap_tensor[index] = cv2.filter2D(
                heatmap_tensor[index], -1, kernel, borderType=cv2.BORDER_CONSTANT, delta=0
            )

        stacked = torch.cat((torch.FloatTensor(heatmap_tensor), label), 0)

        return {"image": image, "label": stacked}
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class RotateNTurns(object):
    """Rotate batched tensors, or re-order junction channels, by quarter turns.

    ``n`` is the number of clockwise quarter turns: ``1`` (one clockwise),
    ``-1`` (one counter-clockwise) or ``2`` (half turn).  Any other value
    leaves the data unchanged.
    """

    # _CW[src] is the channel that channel ``src`` moves to after one
    # clockwise quarter turn: I junctions 0-3, L junctions 4-7, T junctions
    # 8-11, channel 12 (rotation invariant), opening corners 13-16 and icon
    # corners 17-20.
    _CW = (1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 12, 15, 16, 14, 13, 18, 20, 17, 19)

    def rot_tensor(self, t, n):
        """Rotate ``t`` over dims 2 and 3 by ``n`` quarter turns."""
        # One turn clock wise
        if n == 1:
            t = t.flip(2).transpose(3, 2)
        # One turn counter clock wise
        elif n == -1:
            t = t.transpose(3, 2).flip(2)
        # Two turns clock wise
        elif n == 2:
            t = t.flip(2).flip(3)

        return t

    @classmethod
    def _destinations(cls, n):
        """Return the destination channel of channels 0..20 for ``n`` turns.

        ``None`` means the channels keep their place.
        """
        cw = cls._CW
        if n == 1:
            return list(cw)
        if n == -1:
            # Inverse permutation of one clockwise turn.
            return [cw.index(src) for src in range(len(cw))]
        if n == 2:
            # Two clockwise turns applied one after the other.
            return [cw[cw[src]] for src in range(len(cw))]
        return None

    def rot_points(self, t, n):
        """Return a detached copy of ``t`` with dim-1 channels re-ordered.

        Only the first 21 channels are permuted; any further channels are
        copied as they are.
        """
        # Swapping corner ts
        t_sorted = t.clone().detach()
        destinations = self._destinations(n)
        if destinations is not None:
            sources = list(range(len(destinations)))
            # The half turn now swaps channel 1 with channel 3; the earlier
            # code wrote channel 3 into channel 4 and left channel 1 stale.
            t_sorted[:, destinations] = t[:, sources]

        return t_sorted

    def __call__(self, sample, data_type, n):
        # NOTE: any other ``data_type`` falls through and returns None.
        if data_type == "tensor":
            return self.rot_tensor(sample, n)
        elif data_type == "points":
            return self.rot_points(sample, n)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class RandomCropToSizeTorch(object):
    """Pad a sample and cut a random ``size`` window out of it.

    ``size`` is ``(width, height)``.  The sample tensors are indexed as
    ``[channel, row, column]``; they are embedded in a larger canvas (half a
    crop of padding before the content) and a window of ``height`` rows by
    ``width`` columns is taken at a random offset.

    ``data_format`` selects the sample layout:

    * ``"tensor"``: ``"label"`` stacks ``input_slice[0]`` heatmap channels, a
      room channel and an icon channel.
    * ``"dict"``: ``"label"`` holds room and icon channels, points live under
      ``"heatmaps"`` and ``"scale"`` is passed through.
    * ``"dict furu"``: like ``"dict"`` but points live under
      ``"heatmap_points"`` and there is no ``"scale"``.

    ``fill`` gives the padding values for the room and icon channels.
    """

    def __init__(
        self,
        input_slice=(21, 1, 1),
        size=(256, 256),
        fill=(0, 0),
        data_format="tensor",
        dtype=torch.float32,
        max_size=None,
    ):
        self.size = size
        self.width = size[0]
        self.height = size[1]
        self.dtype = dtype
        self.fill = fill
        self.max_size = max_size
        self.input_slice = input_slice

        # NOTE: other values leave ``self.augment`` unset.
        if data_format == "dict":
            self.augment = self.augment_dict
        elif data_format == "tensor":
            self.augment = self.augment_tesor
        elif data_format == "dict furu":
            self.augment = self.augment_dict_furu

    def __call__(self, sample):
        return self.augment(sample)

    def _crop_origin(self, canvas_h, canvas_w):
        """Draw the top-left corner of a crop that fits inside the canvas."""
        # Rows are bounded by the crop height and columns by the crop width;
        # mixing the two up only works for square crops.
        removed_up = random.randint(0, canvas_h - self.height)
        removed_left = random.randint(0, canvas_w - self.width)
        return removed_up, removed_left

    def _crop(self, tensor, removed_up, removed_left):
        """Cut the ``height`` x ``width`` window starting at the given corner."""
        return tensor[
            :, removed_up : removed_up + self.height, removed_left : removed_left + self.width
        ]

    def augment_tesor(self, sample):
        """Crop a sample whose label already contains heatmap channels."""
        image, label = sample["image"], sample["label"]
        img_w = image.shape[2]
        img_h = image.shape[1]
        pad_w = int(self.width / 2)
        pad_h = int(self.height / 2)

        # The canvas is at least two crops large along each axis.
        new_w = self.width + max(img_w, self.width)
        new_h = self.height + max(img_h, self.height)
        rows = slice(pad_h, img_h + pad_h)
        cols = slice(pad_w, img_w + pad_w)
        n_heatmaps = self.input_slice[0]

        new_image = torch.zeros([image.shape[0], new_h, new_w], dtype=self.dtype)
        new_image[:, rows, cols] = image

        new_heatmaps = torch.zeros([n_heatmaps, new_h, new_w], dtype=self.dtype)
        new_heatmaps[:, rows, cols] = label[:n_heatmaps]

        new_rooms = torch.full((self.input_slice[1], new_h, new_w), self.fill[0])
        new_rooms[:, rows, cols] = label[n_heatmaps]
        new_icons = torch.full((self.input_slice[2], new_h, new_w), self.fill[1])
        new_icons[:, rows, cols] = label[n_heatmaps + self.input_slice[1]]

        label = torch.cat((new_heatmaps, new_rooms, new_icons), 0)

        removed_up, removed_left = self._crop_origin(new_h, new_w)

        return {
            "image": self._crop(new_image, removed_up, removed_left),
            "label": self._crop(label, removed_up, removed_left),
        }

    # Correctly spelled alias; ``augment_tesor`` is kept for existing callers.
    augment_tensor = augment_tesor

    def _crop_with_points(self, sample, points_key):
        """Pad and crop image, label and the points stored under ``points_key``.

        Returns ``(image, label, heatmap_points)``; the points are expressed
        in crop coordinates and points outside the crop are dropped.
        """
        image, label = sample["image"], sample["label"]
        heatmap_points = sample[points_key]
        img_w = image.shape[2]
        img_h = image.shape[1]
        pad_w = int(self.width / 2)
        pad_h = int(self.height / 2)

        new_w = self.width + img_w
        new_h = self.height + img_h
        rows = slice(pad_h, img_h + pad_h)
        cols = slice(pad_w, img_w + pad_w)

        # The image canvas is padded with 255.
        new_image = torch.full([image.shape[0], new_h, new_w], 255)
        new_image[:, rows, cols] = image

        new_rooms = torch.full((1, new_h, new_w), self.fill[0])
        new_rooms[:, rows, cols] = label[0]
        new_icons = torch.full((1, new_h, new_w), self.fill[1])
        new_icons[:, rows, cols] = label[1]

        label = torch.cat((new_rooms, new_icons), 0)

        removed_up, removed_left = self._crop_origin(new_h, new_w)

        # Move the points into canvas coordinates before clipping them.
        shifted = {
            junction_type: [[point[0] + pad_w, point[1] + pad_h] for point in points]
            for junction_type, points in heatmap_points.items()
        }
        heatmap_points = clip_heatmaps(
            shifted, removed_left, removed_left + self.width, removed_up, removed_up + self.height
        )

        return (
            self._crop(new_image, removed_up, removed_left),
            self._crop(label, removed_up, removed_left),
            heatmap_points,
        )

    def augment_dict(self, sample):
        """Crop a sample with points under ``"heatmaps"`` and a ``"scale"``."""
        image, label, heatmap_points = self._crop_with_points(sample, "heatmaps")

        return {"image": image, "label": label, "heatmaps": heatmap_points, "scale": sample["scale"]}

    def augment_dict_furu(self, sample):
        """Crop a sample with points under ``"heatmap_points"``."""
        image, label, heatmap_points = self._crop_with_points(sample, "heatmap_points")

        return {"image": image, "label": label, "heatmap_points": heatmap_points}
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
class ColorJitterTorch(object):
    """Randomly perturb brightness, contrast and saturation of an image.

    Each adjustment blends the image with a reference image (black, mean
    gray or the grayscale image) using a weight drawn uniformly from
    ``[1 - var, 1 + var)``; results are clamped to ``[0, 255]``.  The image
    is indexed as ``[channel, ...]`` with channels 0, 1, 2 weighted as red,
    green and blue.
    """

    def __init__(self, b_var=0.4, c_var=0.4, s_var=0.4, dtype=torch.float32, version="dict"):
        self.b_var = b_var
        self.c_var = c_var
        self.s_var = s_var
        self.dtype = dtype
        self.version = version

    def __call__(self, sample):
        """Jitter ``sample["image"]`` in place of the old entry and return ``sample``."""
        adjusted = sample["image"]
        steps = (
            (self.brightness, self.b_var),
            (self.contrast, self.c_var),
            (self.saturation, self.s_var),
        )
        for adjust, strength in steps:
            adjusted = adjust(adjusted, strength)
        sample["image"] = adjusted

        return sample

    def blend(self, img_1, img_2, var):
        """Mix ``img_1`` with ``img_2`` using a random weight around 1."""
        alpha = 1 + torch.tensor([0], dtype=self.dtype).uniform_(-var, var)
        mixed = img_1 * alpha + (1 - alpha) * img_2

        return torch.clamp(mixed, min=0, max=255)

    def grayscale(self, img):
        """Return the weighted gray value repeated over three channels."""
        gray = img[0] * 0.299 + img[1] * 0.587 + img[2] * 0.114
        gray = torch.clamp(gray, min=0, max=255)

        return torch.stack((gray, gray, gray), dim=0)

    def saturation(self, img, var):
        """Blend the image with its grayscale version."""
        return self.blend(img, self.grayscale(img), var)

    def brightness(self, img, var):
        """Blend the image with an all-zero image."""
        return self.blend(img, torch.zeros(img.shape), var)

    def contrast(self, img, var):
        """Blend the image with a constant image of its mean gray value."""
        gray = self.grayscale(img)
        mean_color = gray.mean()
        flat = torch.full(gray.shape, mean_color)

        return self.blend(img, flat, var)
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
class ResizePaddedTorch(object):
    """Shrink a sample to fit ``size`` and centre it on a padded canvas.

    Tensors are indexed as ``[channel, row, column]``.  ``size[0]`` is used
    as the number of canvas rows and ``size[1]`` as the number of columns.
    The aspect ratio is kept; the scale factor never exceeds 1 because the
    (unchanged) channel count takes part in the minimum.

    ``fill`` gives the padding values of the room and icon label channels.
    ``data_format`` selects the sample layout:

    * ``"tensor"``: ``"label"`` stacks 21 heatmap channels, a room channel
      and an icon channel.
    * ``"dict furu"``: ``"label"`` holds room and icon channels, points live
      under ``"heatmap_points"``.
    * ``"dict"``: like ``"dict furu"`` but points live under ``"heatmaps"``,
      ``"scale"`` is passed through, points that leave the canvas are
      dropped and the image is padded with 1 instead of 255.
    """

    def __init__(self, fill, size=(256, 256), both=True, dtype=torch.float32, data_format="tensor"):
        self.size = size
        self.width = size[0]
        self.height = size[1]
        self.both = both
        self.dtype = dtype
        self.fill = fill
        self.fill_cval = 255
        # NOTE: other values leave ``self.augment`` unset.
        if data_format == "tensor":
            self.augment = self.augment_tensor
        elif data_format == "dict furu":
            self.augment = self.augment_dict_furu
        elif data_format == "dict":
            self.augment = self.augment_dict
            self.fill_cval = 1

    def __call__(self, sample):
        """Resize ``sample["image"]`` (stored back into ``sample``), then the rest."""
        # The image is resized with bilinear interpolation.
        image, _, _, _ = self.resize_padded(
            sample["image"], self.size, fill_cval=self.fill_cval, image=True, mode="bilinear", aling_corners=False
        )
        sample["image"] = image

        return self.augment(sample)

    def augment_tensor(self, sample):
        """Resize the heatmap, room and icon channels of a tensor label."""
        image, label = sample["image"], sample["label"]

        if self.both:
            # Heatmaps are interpolated bilinearly, class maps with
            # nearest-neighbour so that class ids are not mixed.
            heatmaps, _, _, _ = self.resize_padded(label[:21], self.size, mode="bilinear", aling_corners=False)
            rooms_padded, _, _, _ = self.resize_padded(label[[21]], self.size, mode="nearest", fill_cval=self.fill[0])
            icons_padded, _, _, _ = self.resize_padded(
                label[[22]],
                self.size,
                mode="nearest",
                fill_cval=self.fill[1],
            )
            label = torch.cat((heatmaps, rooms_padded, icons_padded), dim=0)

        return {"image": image, "label": label}

    def _resize_label_and_points(self, label, heatmap_points, drop_outside):
        """Resize a two-channel label and move the points accordingly.

        Returns ``(label, heatmap_points)``.  With ``drop_outside`` points
        that fall outside the canvas are removed.
        """
        rooms_padded, _, _, _ = self.resize_padded(label[[0]], self.size, mode="nearest", fill_cval=self.fill[0])
        icons_padded, ratio, y_pad, x_pad = self.resize_padded(
            label[[1]], self.size, mode="nearest", fill_cval=self.fill[1]
        )
        label = torch.cat((rooms_padded, icons_padded), dim=0)

        # Canvas extent along rows (y) and columns (x).
        max_y, max_x = self.size[0], self.size[1]

        new_heatmap_points = dict()
        for junction_type, points in heatmap_points.items():
            new_heatmap_points_per_type = []
            for point in points:
                # Coordinates are zero based, so scaling needs no offset
                # besides the padding.
                new_x = point[0] * ratio + x_pad
                new_y = point[1] * ratio + y_pad
                inside = 0 <= new_x < max_x and 0 <= new_y < max_y
                if inside or not drop_outside:
                    new_heatmap_points_per_type.append([new_x, new_y])
            new_heatmap_points[junction_type] = new_heatmap_points_per_type

        return label, new_heatmap_points

    def augment_dict_furu(self, sample):
        """Resize a sample with points under ``"heatmap_points"``."""
        image = sample["image"]
        label, heatmap_points = self._resize_label_and_points(
            sample["label"], sample["heatmap_points"], drop_outside=False
        )

        return {"image": image, "label": label, "heatmap_points": heatmap_points}

    def augment_dict(self, sample):
        """Resize a sample with points under ``"heatmaps"`` and a ``"scale"``."""
        image = sample["image"]
        scale = sample["scale"]
        label, heatmap_points = self._resize_label_and_points(
            sample["label"], sample["heatmaps"], drop_outside=True
        )

        return {"image": image, "label": label, "heatmaps": heatmap_points, "scale": scale}

    def resize_padded(self, img, new_shape, image=False, fill_cval=0, mode="nearest", aling_corners=None):
        """Resize ``img`` to fit ``new_shape`` and centre it on the canvas.

        Args:
            img: tensor indexed as ``[channel, row, column]``.
            new_shape: target ``(rows, columns)`` used to compute the ratio.
            image: unused.
            fill_cval: value of the canvas outside the resized content.
            mode: interpolation mode handed to ``interpolate``.
            aling_corners: ``align_corners`` argument handed to
                ``interpolate`` (parameter name kept for existing callers).

        Returns:
            ``(new_img, ratio, row_pad, col_pad)``: the padded tensor of
            ``self.size`` extent, the scale factor as a 0-dim tensor, and the
            offsets of the content along dim 1 and dim 2.
        """
        new_shape = torch.tensor([img.shape[0], new_shape[0], new_shape[1]], dtype=self.dtype)
        old_shape = torch.tensor(img.shape, dtype=self.dtype)

        ratio = (new_shape / old_shape).min()
        img_s = torch.tensor(img.shape[1:], dtype=self.dtype)
        interm_shape = (ratio * img_s).ceil()

        # ``interpolate`` expects integer sizes, not float tensors.
        interm_shape = [int(interm_shape[0]), int(interm_shape[1])]

        interm_img = torch.nn.functional.interpolate(
            img.unsqueeze(0), size=interm_shape, mode=mode, align_corners=aling_corners
        ).squeeze(0)

        # Give the canvas the dtype of the resized content: with an integer
        # ``fill_cval`` torch.full would otherwise create an integer canvas
        # and truncate interpolated values on assignment.
        canvas_shape = (interm_img.shape[0], self.size[0], self.size[1])
        new_img = torch.full(canvas_shape, fill_cval, dtype=interm_img.dtype)

        row_pad = int((self.size[0] - interm_img.shape[1]) / 2)
        col_pad = int((self.size[1] - interm_img.shape[2]) / 2)

        new_img[:, row_pad : interm_img.shape[1] + row_pad, col_pad : interm_img.shape[2] + col_pad] = interm_img

        return new_img, ratio, row_pad, col_pad
|
data_preprocess/cubicasa5k/combine_json.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def combine_json_files(input_pattern, data_path, split_type, output_file, start_image_id=0):
    """
    Combines multiple COCO-style JSON annotation files into a single file.

    Images are renumbered consecutively from ``start_image_id`` and
    annotations from 0.  Each image's ``file_name`` becomes its new id,
    zero-padded to five digits, plus ``.png``; if a PNG named after the JSON
    file exists in ``{data_path}/{split_type}`` it is renamed to match.

    Besides ``output_file`` two side files are written next to it:
    ``<name>_image_id_mapping.json`` (old image id -> new image id) and, when
    something was skipped, ``<name>_skipped.txt`` (one old image id per line).

    Args:
        input_pattern: Glob pattern to match the input JSON files (e.g., "annotations/*.json")
        data_path: Root directory that contains one image folder per split
        split_type: Name of the split folder below ``data_path`` (e.g., "train")
        output_file: Path to the output combined JSON file
        start_image_id: First image id to assign

    Returns:
        The combined dict with "images", "annotations" and "categories".
    """
    # Initialize combined data structure
    combined_data = {"images": [], "annotations": [], "categories": []}

    next_image_id = start_image_id
    next_annotation_id = 0
    skip_file_list = []
    image_id_mapping = {}

    # Find all matching JSON files
    json_files = sorted(glob.glob(input_pattern))
    print(f"Found {len(json_files)} JSON files to combine")

    # Process each file
    for i, json_file in enumerate(json_files):
        print(f"Processing file {i + 1}/{len(json_files)}: {json_file}")

        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        images = data.get("images", [])
        annotations = data.get("annotations", [])

        # Take the categories from the first file that provides any.
        if not combined_data["categories"] and data.get("categories"):
            combined_data["categories"] = data["categories"]

        # Files without annotations contribute nothing; remember their image.
        if not annotations:
            if images:
                skip_file_list.append(images[0]["id"])
            continue

        org_file_name = os.path.basename(json_file).replace(".json", ".png")

        # Process images
        for image in images:
            if image["id"] in image_id_mapping:
                # Image id already seen in an earlier file.
                skip_file_list.append(image["id"])
                continue
            image_id_mapping[image["id"]] = next_image_id
            image["id"] = next_image_id
            next_image_id += 1
            image["file_name"] = str(image["id"]).zfill(5) + ".png"
            # Rename the PNG on disk so that it matches the new file name.
            if image["file_name"] != org_file_name and os.path.exists(f"{data_path}/{split_type}/{org_file_name}"):
                shutil.move(
                    f"{data_path}/{split_type}/{org_file_name}", f"{data_path}/{split_type}/{image['file_name']}"
                )
            combined_data["images"].append(image)

        # Process annotations
        for annotation in annotations:
            annotation["id"] = next_annotation_id
            next_annotation_id += 1
            annotation["image_id"] = image_id_mapping[annotation["image_id"]]
            combined_data["annotations"].append(annotation)

    # Write combined data to output file
    output_path = Path(output_file)
    output_path.parent.mkdir(exist_ok=True, parents=True)
    side_file_prefix = output_path.name.split(".")[0]

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(combined_data, f, indent=2)

    with open(output_path.parent / f"{side_file_prefix}_image_id_mapping.json", "w", encoding="utf-8") as f:
        json.dump(image_id_mapping, f, indent=2)

    if skip_file_list:
        with open(output_path.parent / f"{side_file_prefix}_skipped.txt", "w", encoding="utf-8") as f:
            f.write("\n".join([str(x) for x in skip_file_list]))

    print(f"Combined data written to {output_file}")
    print(f"Total images: {len(combined_data['images'])}")
    print(f"Total annotations: {len(combined_data['annotations'])}")
    print(f"Total categories: {len(combined_data['categories'])}")
    print(f"Skipped images: {len(skip_file_list)}")

    return combined_data
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
    import argparse

    # Combine the per-image annotation files of every split, numbering the
    # images consecutively across the splits.
    parser = argparse.ArgumentParser(description="Combine multiple COCO-style JSON annotation files")
    parser.add_argument(
        "--input",
        required=True,
        help="Dataset root directory containing annotations_json/<split>/*.json and <split>/*.png",
    )
    parser.add_argument("--output", required=True, help="Output directory; one <split>.json is written per split")

    args = parser.parse_args()

    splits = ["train", "val", "test"]
    start_image_id = 0
    for i, split in enumerate(splits):
        if i > 0:
            # Continue numbering after the PNG files of the previous split.
            start_image_id += len(list(Path(f"{args.input}/{splits[i - 1]}").glob("*.png")))

        combine_json_files(
            f"{args.input}/annotations_json/{split}/*.json",
            args.input,
            split,
            f"{args.output}/{split}.json",
            start_image_id=start_image_id,
        )
|
data_preprocess/cubicasa5k/create_coco_cc5k.py
ADDED
|
@@ -0,0 +1,672 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from multiprocessing import Pool
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
from loaders import FloorplanSVG
|
| 12 |
+
from matplotlib.patches import Patch
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from shapely.geometry import Polygon
|
| 15 |
+
from skimage import measure
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
| 19 |
+
|
| 20 |
+
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
| 21 |
+
from common_utils import resort_corners
|
| 22 |
+
from stru3d.stru3d_utils import type2id
|
| 23 |
+
|
| 24 |
+
#### ORIGINAL ROOM NAMES & ICON_NAMES ####
# Label id -> display name for the room mask; the id is the position in the list.
ROOM_NAMES = dict(
    enumerate(
        [
            "Background",
            "Outdoor",
            "Wall",
            "Kitchen",
            "Living Room",
            "Bed Room",
            "Bath",
            "Entry",
            "Railing",
            "Storage",
            "Garage",
            "Undefined",
        ]
    )
)

# Label id -> display name for the icon mask; the id is the position in the list.
ICON_NAMES = dict(
    enumerate(
        [
            "No Icon",
            "Window",
            "Door",
            "Closet",
            "Electrical Applience",
            "Toilet",
            "Sink",
            "Sauna Bench",
            "Fire Place",
            "Bathtub",
            "Chimney",
        ]
    )
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Target of every combined CubiCasa5k label id (position in the tuple):
# a string is looked up in `type2id`, an int is used as-is, None drops the label.
_CC5K_TO_S3D_TARGETS = (
    None,  # 0: "Background"
    "balcony",  # 1: "Outdoor" -> balcony
    None,  # 2: "Wall" has no direct match
    "kitchen",  # 3: Kitchen
    "living room",  # 4: Living Room
    "bedroom",  # 5: Bed Room
    "bathroom",  # 6: Bath
    18,  # 7: 'Entry' has no direct match, fixed id
    19,  # 8: "Railing" has no direct match, fixed id
    "store room",  # 9: Storage
    "garage",  # 10: Garage
    "undefined",  # 11: Undefined
    "window",  # 12: Window (icon id 1 shifted by 11)
    "door",  # 13: Door (icon id 2 shifted by 11)
)

CC5K_2_S3D_MAPPING = {
    cc5k_id: type2id[target] if isinstance(target, str) else target
    for cc5k_id, target in enumerate(_CC5K_TO_S3D_TARGETS)
}
|
| 71 |
+
|
| 72 |
+
# Combined CubiCasa5k label id (0..13) -> compact class id; None means "dropped".
# Railing shares the Wall class here.
CC5K_MAPPING = dict(
    zip(
        range(14),
        (
            None,  # Background
            0,  # Outdoor
            1,  # Wall
            2,  # Kitchen
            3,  # Living Room
            4,  # Bed Room
            5,  # Bath
            6,  # Entry
            1,  # Railing -> Wall
            7,  # Storage
            8,  # Garage
            9,  # Undefined
            10,  # Window
            11,  # Door
        ),
    )
)

# Same as CC5K_MAPPING but with Wall and Railing dropped, so the remaining
# classes are numbered 0..10.
CC5K_MAPPING_2 = dict(
    zip(
        range(14),
        (
            None,  # Background
            0,  # Outdoor
            None,  # Wall
            1,  # Kitchen
            2,  # Living Room
            3,  # Bed Room
            4,  # Bath
            5,  # Entry
            None,  # Railing (dropped like Wall)
            6,  # Storage
            7,  # Garage
            8,  # Undefined
            9,  # Window
            10,  # Door
        ),
    )
)
|
| 105 |
+
|
| 106 |
+
# Category name -> category id tables; the id is the position of the name in
# the tuple. These are what `prepare_dict` turns into COCO "categories".

# Classes produced by CC5K_MAPPING (Wall and Railing merged).
CC5K_CLASS_MAPPING = {
    name: idx
    for idx, name in enumerate(
        (
            "Outdoor",
            "Wall, Railing",
            "Kitchen",
            "Living Room",
            "Bed Room",
            "Bath",
            "Entry",
            "Storage",
            "Garage",
            "Undefined",
            "Window",
            "Door",
        )
    )
}

# Classes produced by CC5K_MAPPING_2 (no Wall / Railing class).
CC5K_CLASS_MAPPING_2 = {
    name: idx
    for idx, name in enumerate(
        (
            "Outdoor",
            "Kitchen",
            "Living Room",
            "Bed Room",
            "Bath",
            "Entry",
            "Storage",
            "Garage",
            "Undefined",
            "Window",
            "Door",
        )
    )
}

# Classes used together with CC5K_2_S3D_MAPPING.
CLASS_MAPPING = {
    name: idx
    for idx, name in enumerate(
        (
            "living room",
            "kitchen",
            "bedroom",
            "bathroom",
            "balcony",
            "corridor",
            "dining room",
            "study",
            "studio",
            "store room",
            "garden",
            "laundry room",
            "office",
            "basement",
            "garage",
            "undefined",
            "door",
            "window",
            "entry",
            "railing",
        )
    )
}
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def fill_holes_in_mask(binary_mask):
    """
    Dilate a foreground mask and fill everything enclosed by its outer contours.

    The input is binarized (`> 0`), grown with a 15x15 rectangular kernel and
    then every external contour of the grown mask is filled with 1, so interior
    holes disappear. Note that, because of the dilation, the result is also
    larger than the original foreground.

    Args:
        binary_mask (numpy.ndarray): Mask whose positive values mark foreground.

    Returns:
        numpy.ndarray: uint8 mask (values 0/1) of the dilated, hole-free foreground.
    """
    # Normalise to a strict 0/1 uint8 image, which is what OpenCV expects
    foreground = (binary_mask > 0).astype(np.uint8)

    # Grow the foreground before looking for its outline
    grow_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
    grown = cv2.dilate(foreground, grow_kernel, iterations=1)

    # Only the outermost contours are needed: filling them covers all holes
    outer_contours, _ = cv2.findContours(grown, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result = grown.copy()
    cv2.fillPoly(result, outer_contours, 1)
    return result
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def close_contour(contour):
    """Return `contour` closed: the first point is appended unless it already equals the last one."""
    start_point, end_point = contour[0], contour[-1]
    if np.array_equal(start_point, end_point):
        # Already closed: hand back the very same object
        return contour
    return np.vstack((contour, start_point))
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def binary_mask_to_polygon(binary_mask, tolerance=0):
    """Converts a binary mask to COCO polygon representation
    Ref: https://github.com/waspinator/pycococreator/blob/master/pycococreatortools/pycococreatortools.py

    Args:
        binary_mask: a 2D binary numpy array where '1's represent the object
        tolerance: Maximum distance from original points of polygon to approximated
            polygonal chain. If tolerance is 0, the original coordinate array is returned.

    Returns:
        A list of polygons, one per contour with at least 3 points, each a flat
        list [x0, y0, x1, y1, ...]; negative coordinates are clamped to 0.
    """
    polygons = []
    # pad mask to close contours of shapes which start and end at an edge
    padded_binary_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0)
    # Undo the 1-pixel padding offset contour by contour. The contours generally
    # have different lengths, so shifting the whole list at once with
    # np.subtract(contours, 1) would build a ragged array, which recent NumPy rejects.
    contours = [contour - 1 for contour in measure.find_contours(padded_binary_mask, 0.5)]
    for contour in contours:
        contour = close_contour(contour)
        contour = measure.approximate_polygon(contour, tolerance)
        if len(contour) < 3:
            continue
        # find_contours yields (row, col); COCO wants (x, y)
        contour = np.flip(contour, axis=1)
        segmentation = contour.ravel().tolist()
        # after padding and subtracting 1 we may get -0.5 points in our segmentation
        segmentation = [0 if i < 0 else i for i in segmentation]
        polygons.append(segmentation)

    return polygons
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def extract_icon_cv2(mask, start_cls_id=11, skip_classes=None):
    """
    Extract one simplified polygon per connected region of every label in `mask`.

    Args:
        mask: 2D label mask.
        start_cls_id: offset added to every label value to obtain the output class id.
        skip_classes: label values (before the offset) to ignore; defaults to none.

    Returns:
        A tuple `(room_polygons, new_mask)` where `room_polygons` is a list of
        `[coords, class_id]` entries (`coords` being the exterior coordinates of
        the simplified polygon) and `new_mask` is an integer mask holding the
        shifted class id for every processed label and 0 elsewhere.
    """
    skip_classes = () if skip_classes is None else skip_classes
    room_polygons = []
    new_mask = np.zeros(mask.shape, dtype=int)

    # window, door
    for room_id in np.unique(mask):
        if room_id in skip_classes:
            continue
        true_room_id = int(room_id) + start_cls_id
        # Create binary mask for this label
        room_mask = (mask == room_id).astype(np.uint8)
        # Accumulate into the output mask; rebuilding it from zeros on every
        # iteration would keep only the last processed label.
        new_mask = np.where(room_mask, true_room_id, new_mask)

        # Find contours using OpenCV
        contours, _ = cv2.findContours(room_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for cnt in contours:
            polygon = [tuple(point[0]) for point in cnt]
            if len(polygon) < 3:
                # Not enough points to form a polygon
                continue

            simplified_poly = Polygon(polygon).simplify(tolerance=0.5, preserve_topology=True)
            room_polygons.append([list(simplified_poly.exterior.coords), true_room_id])

    return room_polygons, new_mask
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def visualize_room_polygons(mask, room_polygons, class_names, save_path="cubicasa_debug.png", bg_polygons=None):
    """
    Draw a label mask with polygon outlines and a class legend, and save it as an image.

    Args:
        mask: 2D label mask. Its values are used as indices into `class_names`
            and into a 21-step colour scale, so they are expected to lie in [0, 20].
        room_polygons: Iterable of `(polygon, class_id)` pairs. Each polygon is
            outlined in black and labelled with its class id at the mean of its vertices.
        class_names: Sequence of class names indexed by mask value (used for the legend).
        save_path: Path of the image file to write.
        bg_polygons: Optional iterable of `(polygon, class_id)` pairs that are
            outlined in cyan without a label.
    """
    # Size the figure so that, at 100 dpi, it has the same pixel size as the mask
    dpi = 100
    figsize = (mask.shape[1] / dpi, mask.shape[0] / dpi)  # Convert pixels to inches

    # Get unique classes from the mask
    unique_classes = np.unique(mask)

    # Create a discrete colormap. pyplot.get_cmap is used because
    # matplotlib.cm.get_cmap is no longer available in recent Matplotlib releases.
    cmap = plt.get_cmap("gist_ncar", 256)  # nipy_spectral

    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0, 0, 1, 1])
    ax.imshow(mask, cmap=cmap, interpolation="nearest", alpha=0.6, vmin=0, vmax=20)

    # Plot each room polygon
    for polygon, room_cls in room_polygons:
        polygon_array = np.array(polygon).copy()
        ax.plot(polygon_array[:, 0], polygon_array[:, 1], "k-", linewidth=2)

        # Add class id label at the mean of the polygon vertices
        centroid_x = np.mean(polygon_array[:, 0])
        centroid_y = np.mean(polygon_array[:, 1])
        ax.text(
            centroid_x,
            centroid_y,
            str(room_cls),
            fontsize=12,
            ha="center",
            va="center",
            bbox=dict(facecolor="white", alpha=0.7),
        )

    if bg_polygons is not None:
        # Plot each background polygon (outline only)
        for polygon, room_cls in bg_polygons:
            polygon_array = np.array(polygon).copy()
            ax.plot(polygon_array[:, 0], polygon_array[:, 1], "c-", linewidth=2)

    # Create custom legend elements
    legend_elements = []
    # imshow maps value v to cmap(v / 20) because of vmin=0, vmax=20;
    # norm[v] reproduces exactly that position on the colour scale.
    norm = np.linspace(0, 1, 21)

    for cls in sorted(unique_classes):
        # Get the exact same color that imshow uses
        color = cmap(norm[int(cls)])

        cls_name = f"{int(cls)}_{class_names[int(cls)]}"
        legend_elements.append(Patch(facecolor=color, edgecolor="black", label=f"{cls_name}", alpha=0.6))

    # Add the legend to the plot
    ax.legend(
        handles=legend_elements,
        loc="best",
        title="Classes",
        fontsize=20,
        markerscale=4,
        title_fontsize=28,
    )

    ax.axis("equal")
    ax.axis("off")
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    # Close this specific figure so repeated calls do not accumulate open figures
    plt.close(fig)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def config():
    """
    Parse the command line arguments of the CubiCasa5k -> COCO conversion script.

    Returns:
        argparse.Namespace with `data_root`, `output` and `disable_wd2line`.
    """
    parser = argparse.ArgumentParser(description="Generate coco format data for CubiCasa5k")
    # NOTE: the default value is kept unchanged for backward compatibility even
    # though it still carries the Structured3D folder name.
    parser.add_argument(
        "--data_root", default="Structured3D_panorama", type=str, help="path to raw CubiCasa5k dataset folder"
    )
    parser.add_argument("--output", default="coco_cubicasa5k", type=str, help="path to output folder")
    # When set, door/window polygons are kept as polygons instead of being
    # converted to single line segments.
    parser.add_argument("--disable_wd2line", action="store_true")

    return parser.parse_args()
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def save_image(image_path, output_path, mask=None):
    """
    Re-save an image as RGB without its ICC profile, optionally whitening masked-out pixels.

    ref: https://github.com/ultralytics/ultralytics/issues/339

    Args:
        image_path: Path of the image to read.
        output_path: Path the processed image is written to.
        mask: Optional array; every pixel where `mask == 0` is set to 255 (white).
            A 2D mask is broadcast over the colour channels.
    """
    # Use a context manager so the underlying file handle is released right
    # away; convert() returns an independent, fully loaded image.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img.info.pop("icc_profile", None)

    if mask is not None:
        img_array = np.array(img)
        if len(mask.shape) == 2 and len(img_array.shape) == 3:
            # Add a channel axis so the mask broadcasts against (H, W, 3)
            mask = mask[:, :, np.newaxis]
        masked_img = np.where(mask == 0, 255, img_array)
        img = Image.fromarray(masked_img.astype(np.uint8))

    img.save(output_path)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def remove_polygons_by_type(polygons, skip_types=[]):
    """Return `[polygon, type]` entries of `polygons` whose type is not listed in `skip_types`."""
    return [[shape, shape_type] for shape, shape_type in polygons if shape_type not in skip_types]
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def merge_rooms_and_icons(room_polygons, icon_polygons):
    """
    Concatenate room and icon polygons into one list.

    Icon type ids are shifted by 11 so that they follow the room type ids;
    the input lists are left untouched.
    """
    shifted_icons = [[shape, icon_type + 11] for shape, icon_type in icon_polygons]
    return room_polygons + shifted_icons
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2):
    """
    Build a padded COCO bounding box `[x_min, y_min, width, height]`.

    The extent of the given x / y coordinates is enlarged by `bound_pad` on
    every side and clamped to `[0, image_width - 1]` and `[0, image_height - 1]`.
    """

    def _padded_span(values, upper_limit):
        # Returns (start, length) of the padded and clamped coordinate range
        distinct = np.unique(values)
        start = np.maximum(np.min(distinct) - bound_pad, 0)
        stop = np.minimum(np.max(distinct) + bound_pad, upper_limit)
        return start, stop - start

    x_start, box_width = _padded_span(bb_x, image_width - 1)
    y_start, box_height = _padded_span(bb_y, image_height - 1)
    return [x_start, y_start, box_width, box_height]
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def process_floorplan(
    image_set,
    scene_id,
    start_scene_id,
    args,
    save_dir,
    annos_folder,
    use_org_cc5k_classs=False,
    vis_fp=False,
    wd2line=False,
):
    """
    Convert one floorplan sample into a masked image plus a per-image COCO JSON file.

    Args:
        image_set: Sample providing "label" (object with `.numpy()`; channel 0 is
            used as room labels, channel 1 as icon labels), "room_polygon",
            "room_type", "icon_polygon", "icon_type" and "folder".
        scene_id: Index of the sample inside its split.
        start_scene_id: Offset added to `scene_id` to obtain the image id.
        args: Parsed CLI arguments; only `args.data_root` is read here.
        save_dir: Folder the masked image `<id>.png` is written to. Auxiliary
            files go to the sibling folder `<save_dir>_aux`.
        annos_folder: Folder the per-image annotation `<id>.json` is written to.
        use_org_cc5k_classs: Use the CubiCasa5k class set (CC5K_MAPPING_2) instead
            of the Structured3D-style one (CC5K_2_S3D_MAPPING).
        vis_fp: Also write debug visualisations into the auxiliary folder.
        wd2line: Replace door/window polygons by a single line segment.
    """
    if use_org_cc5k_classs:
        class_mapping_dict = CC5K_MAPPING_2  # old: CC5K_MAPPING
        class_to_index_dict = CC5K_CLASS_MAPPING_2
        door_window_index = [10, 9]
    else:
        class_mapping_dict = CC5K_2_S3D_MAPPING
        class_to_index_dict = CLASS_MAPPING
        door_window_index = [16, 17]

    mask = image_set["label"].numpy()
    room_polygons = [[poly, poly_type] for poly, poly_type in zip(image_set["room_polygon"], image_set["room_type"])]
    icon_polygons = [[poly, poly_type] for poly, poly_type in zip(image_set["icon_polygon"], image_set["icon_type"])]

    image_height, image_width = mask.shape[1:]
    coco_annotation_dict_list = []

    # for storing
    save_dict = prepare_dict(class_to_index_dict)  # old: CC5K_CLASS_MAPPING

    instance_id = 0
    img_id = int(scene_id) + start_scene_id
    img_name = str(img_id).zfill(5)
    img_dict = {}
    img_dict["file_name"] = img_name + ".png"
    img_dict["id"] = img_id
    img_dict["width"] = image_width
    img_dict["height"] = image_height

    # The auxiliary folder receives the foreground mask of EVERY image (see
    # cv2.imwrite below), so it has to exist even when visualisation is off;
    # cv2.imwrite does not create missing directories and fails silently.
    aux_dir = save_dir.rstrip("/") + "_aux"
    os.makedirs(aux_dir, exist_ok=True)

    if vis_fp:
        visualize_room_polygons(
            mask[0],
            room_polygons,
            list(ROOM_NAMES.values()),
            save_path=f"{aux_dir}/{img_name}_room.png",
        )
        visualize_room_polygons(
            mask[1],
            icon_polygons,
            list(ICON_NAMES.values()),
            bg_polygons=room_polygons,
            save_path=f"{aux_dir}/{img_name}_icon.png",
        )

    #### FILTER NON-USE TYPES
    # DROP BG
    room_skip_types = [0]
    filtered_room_polygons = remove_polygons_by_type(room_polygons, skip_types=room_skip_types)

    # Exclude all furnitures, excepts window, door
    icon_skip_types = [0, *list(range(3, 11))]
    filtered_icon_polygons = remove_polygons_by_type(icon_polygons, skip_types=icon_skip_types)

    #### COMBINED
    combined_polygons = merge_rooms_and_icons(filtered_room_polygons, filtered_icon_polygons)

    filtered_mask1 = mask[0].copy()
    filtered_mask1[np.isin(mask[0], room_skip_types)] = 0

    filtered_mask2 = mask[1].copy()
    filtered_mask2[np.isin(mask[1], icon_skip_types)] = 0
    # Shift the remaining icon ids behind the room ids, like merge_rooms_and_icons does
    filtered_mask2[filtered_mask2 != 0] += 11

    # Icons (windows / doors) take precedence over rooms where both are present
    filtered_mask = np.where(filtered_mask2 != 0, filtered_mask2, filtered_mask1)

    # Mask with output class ids shifted by one; only used for the "_final" visualisation
    new_filtered_mask = filtered_mask.copy()
    for src_type, dest_type in class_mapping_dict.items():
        if dest_type is None:
            continue
        new_filtered_mask[filtered_mask == src_type] = dest_type + 1

    # Foreground = any pixel carrying a room or an icon label
    binary_mask = np.where((mask[0] + mask[1]) != 0, 1, 0).astype(np.uint8)
    filled_mask = fill_holes_in_mask(binary_mask)
    cv2.imwrite(f"{aux_dir}/{img_name + '_mask.png'}", filled_mask.astype(np.uint8) * 255)

    # Save the floorplan image with everything outside the foreground whitened
    save_image(
        f"{args.data_root}/{image_set['folder']}/F1_scaled.png",
        f"{save_dir}/{img_name + '.png'}",
        mask=filled_mask,
    )
    if vis_fp:
        save_image(
            f"{args.data_root}/{image_set['folder']}/F1_scaled.png",
            f"{aux_dir}/{img_name + '_org.png'}",
            mask=None,
        )

    output_polygon_list = []
    combined_polygon_list = []
    for polygon, poly_type in combined_polygons:
        poly_shapely = Polygon(polygon)
        area = poly_shapely.area

        org_poly_type = poly_type
        poly_type = class_mapping_dict[poly_type]
        if poly_type is None:
            continue

        # Drop tiny polygons: rooms below 100 px^2, doors/windows below 1 px^2
        if poly_type not in door_window_index and area < 100:
            continue
        if poly_type in door_window_index and area < 1:
            continue

        rectangle_shapely = poly_shapely.envelope
        polygon = np.array(polygon)

        ### here we convert door/window annotation into a single line
        if poly_type in door_window_index and wd2line:
            if polygon.shape[0] > 4:
                if len(polygon) == 5 and (polygon[0] == polygon[-1]).all():
                    polygon = polygon[:-1]  # drop last point since it is same as first
                else:
                    # Anything else is approximated by its minimum rotated rectangle
                    bounding_rect = np.array(poly_shapely.minimum_rotated_rectangle.exterior.coords)
                    polygon = bounding_rect[:4]

            assert polygon.shape[0] == 4
            # Midpoints of the four edges; the line is the longer of the two
            # segments connecting midpoints of opposite edges.
            midp_1 = (polygon[0] + polygon[1]) / 2
            midp_2 = (polygon[1] + polygon[2]) / 2
            midp_3 = (polygon[2] + polygon[3]) / 2
            midp_4 = (polygon[3] + polygon[0]) / 2

            dist_1_3 = np.square(midp_1 - midp_3).sum()
            dist_2_4 = np.square(midp_2 - midp_4).sum()
            # np.vstack is the non-deprecated spelling of np.row_stack
            if dist_1_3 > dist_2_4:
                polygon = np.vstack([midp_1, midp_3])
            else:
                polygon = np.vstack([midp_2, midp_4])

        coco_seg_poly = []
        poly_sorted = resort_corners(polygon)

        for p in poly_sorted:
            coco_seg_poly += list(p)

        # Slightly wider bounding box
        bb_x, bb_y = rectangle_shapely.exterior.xy
        coco_bb = create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2)

        coco_annotation_dict = {
            "segmentation": [coco_seg_poly],
            "area": area,
            "iscrowd": 0,
            "image_id": img_id,
            "bbox": coco_bb,
            "category_id": poly_type,
            "id": instance_id,
        }
        coco_annotation_dict_list.append(coco_annotation_dict)
        instance_id += 1

        combined_polygon_list.append([np.array(coco_seg_poly).reshape(-1, 2), org_poly_type])
        output_polygon_list.append([np.array(coco_seg_poly).reshape(-1, 2), poly_type + 1])

    save_dict["images"].append(img_dict)
    save_dict["annotations"] += coco_annotation_dict_list

    json_path = f"{annos_folder}/{img_name + '.json'}"
    with open(json_path, "w") as f:
        json.dump(save_dict, f)

    if vis_fp:
        visualize_room_polygons(
            filtered_mask,
            combined_polygon_list,
            list(ROOM_NAMES.values()) + ["window", "door"],
            save_path=f"{aux_dir}/{img_name}_combined.png",
        )
        visualize_room_polygons(
            new_filtered_mask,
            output_polygon_list,
            ["null"] + list(class_to_index_dict.keys()),
            save_path=f"{aux_dir}/{img_name}_final.png",
        )
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def prepare_dict(categories_dict):
    """
    Create an empty COCO-style dictionary with its categories filled in.

    Args:
        categories_dict: Mapping of category name -> category id.

    Returns:
        Dict with empty "images" / "annotations" lists and one "categories"
        entry (supercategory "room") per item of `categories_dict`.
    """
    categories = [
        {"supercategory": "room", "id": cat_id, "name": cat_name} for cat_name, cat_id in categories_dict.items()
    ]
    return {"images": [], "annotations": [], "categories": categories}
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
if __name__ == "__main__":
    args = config()

    ### prepare
    # Output layout:
    #   <output>/<split>/                    masked floorplan images
    #   <output>/annotations_json/<split>/   one COCO JSON file per image
    outFolder = args.output
    annotation_outFolder = os.path.join(outFolder, "annotations_json")

    split_names = ["train", "val", "test"]
    annos_folders = [os.path.join(annotation_outFolder, name) for name in split_names]
    save_folders = [os.path.join(outFolder, name) for name in split_names]
    # makedirs (instead of mkdir) also creates missing parent folders and
    # tolerates folders that already exist.
    for folder in annos_folders + save_folders:
        os.makedirs(folder, exist_ok=True)

    ### begin processing
    start_scene_id = 3500  # following index of s3d data
    split_set = ["train.txt", "val.txt", "test.txt"]

    def wrapper(scene_id):
        # Runs inside a worker process; reads the module-level `dataset`,
        # `save_dir`, `annos_folder` and `start_scene_id` of the current split.
        image_set = dataset[scene_id]
        process_floorplan(
            image_set,
            scene_id,
            start_scene_id,
            args,
            save_dir,
            annos_folder,
            use_org_cc5k_classs=True,
            vis_fp=scene_id < 100,  # debug visualisations only for the first 100 samples of a split
            wd2line=not args.disable_wd2line,
        )

    def worker_init(dataset_obj):
        # Store dataset as global to avoid pickling issues
        global dataset
        dataset = dataset_obj

    for split_id, split_file in enumerate(split_set):
        dataset = FloorplanSVG(args.data_root, split_file, format="txt", original_size=False)
        save_dir = save_folders[split_id]
        print(f"############# {split_file}")

        annos_folder = annos_folders[split_id]
        num_processes = 16
        with Pool(num_processes, initializer=worker_init, initargs=(dataset,)) as p:
            indices = range(len(dataset))
            # list(...) drains the lazy imap iterator so every sample is processed
            list(tqdm(p.imap(wrapper, indices), total=len(dataset)))

        # Image ids of the next split continue after this one
        start_scene_id += len(dataset)
|
data_preprocess/cubicasa5k/floorplan_extraction.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from multiprocessing import Pool
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
from shapely.geometry import Polygon
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
| 15 |
+
from common_utils import resort_corners
|
| 16 |
+
from create_coco_cc5k import create_coco_bounding_box
|
| 17 |
+
|
| 18 |
+
from util.plot_utils import plot_semantic_rich_floorplan_opencv
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def plot_floor(output_coco_polygons, categories_dict, img_w, img_h, save_path, door_window_index=None):
    """Render a semantic floorplan visualisation through the OpenCV plotting helper.

    Args:
        output_coco_polygons: Iterable of ``(segmentation, category_id)`` pairs. Each
            segmentation is a flat ``[x0, y0, x1, y1, ...]`` coordinate sequence.
        categories_dict: COCO ``categories`` entries; converted to an id -> name
            mapping for the plotting helper.
        img_w: Canvas width forwarded to the plotting helper.
        img_h: Canvas height forwarded to the plotting helper.
        save_path: Output location forwarded to the plotting helper.
        door_window_index: Category ids forwarded to the plotting helper as the
            door/window classes. Defaults to ``[10, 9]``.
    """
    # Resolve the default here instead of using a shared mutable default argument.
    if door_window_index is None:
        door_window_index = [10, 9]

    # Turn every flat coordinate list into an (N, 2) int32 corner array.
    gt_sem_rich = [
        [np.array(poly).reshape(-1, 2).astype(np.int32), poly_type]
        for poly, poly_type in output_coco_polygons
    ]

    plot_semantic_rich_floorplan_opencv(
        gt_sem_rich,
        save_path,
        img_w=img_w,
        img_h=img_h,
        door_window_index=door_window_index,
        semantics_label_mapping=get_dataset_class_labels(categories_dict),
        is_bw=False,
    )
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def prepare_dict(categories_dict):
    """Return a fresh COCO-style container with empty image/annotation lists.

    The given ``categories_dict`` object is stored as-is under ``"categories"``.
    """
    return dict(images=[], annotations=[], categories=categories_dict)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def extract_polygons_from_mask(binary_mask, output_mask_path):
    """
    Extract the outer contours of the foreground regions of a mask as polygons
    and save a debug image with those polygons drawn on top of the mask.

    Args:
        binary_mask (numpy.ndarray): Mask with shape (H, W). Every value > 0 is
            treated as foreground, everything else as background.
        output_mask_path (str): File the debug visualisation is written to.

    Returns:
        list: One polygon per external contour, each a list of [x, y] points.
    """
    # Ensure the mask is binary (0 and 1)
    binary_mask = (binary_mask > 0).astype(np.uint8)

    # Only the outermost contours are of interest (holes are ignored)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    polygons = []
    for contour in contours:
        # Approximate the contour to reduce the number of points
        epsilon = 0.001 * cv2.arcLength(contour, True)  # Adjust epsilon for more/less detail
        approx_polygon = cv2.approxPolyDP(contour, epsilon, True)
        # reshape (not squeeze) so that a single-point contour stays [[x, y]]
        # instead of collapsing to a flat [x, y] that cannot be drawn or cropped.
        polygons.append(approx_polygon.reshape(-1, 2).tolist())

    # Convert binary_mask to a 3-channel image to draw colored polylines
    binary_mask_colored = cv2.cvtColor(binary_mask * 255, cv2.COLOR_GRAY2BGR)

    # Outline every polygon in red (OpenCV colours are BGR)
    for polygon in polygons:
        points = np.array(polygon, dtype=np.int32)
        cv2.polylines(binary_mask_colored, [points], isClosed=True, color=(0, 0, 255), thickness=10)

    cv2.imwrite(output_mask_path, binary_mask_colored)

    return polygons
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def read_polygons_from_json(json_file):
    """Load the polygons of a COCO-style annotation file.

    Args:
        json_file (str): Path to a JSON file with ``"categories"`` and
            ``"annotations"`` entries.

    Returns:
        tuple: ``(source_polygons, source_misc, category_dict)`` where
            ``source_polygons`` is a list of ``(segmentation[0], category_id)``
            pairs, ``source_misc`` is the list of the raw annotation dicts in the
            same order, and ``category_dict`` is the file's ``"categories"`` value.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    category_dict = data["categories"]
    annotations = data["annotations"]

    # Only the first segmentation ring of every annotation is used.
    source_polygons = [(anno["segmentation"][0], anno["category_id"]) for anno in annotations]
    source_misc = list(annotations)

    return source_polygons, source_misc, category_dict
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def get_dataset_class_labels(category_dict):
    """Build an ``id -> name`` lookup from a sequence of COCO category entries."""
    labels = {}
    for category in category_dict:
        labels[category["id"]] = category["name"]
    return labels
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def check_all_window_door_inside(polygons, door_window_index):
    """Tell whether every ``(polygon, category_id)`` pair is a door/window.

    Returns ``True`` when no category falls outside ``door_window_index``,
    which includes the case of an empty ``polygons`` list.
    """
    categories = [category for _, category in polygons]
    outsiders = [category for category in categories if category not in door_window_index]
    return not outsiders
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def extract_region_and_annotation(
    source_image,
    source_annot_path,
    region_polygons,
    start_image_id,
    output_image_dir="output",
    output_annot_dir="annotations",
    output_aux_dir="output_aux",
    vis_aux=True,
):
    """
    Crop every region polygon out of the source image and write, per crop, the
    image plus a COCO annotation file with the source polygons that lie entirely
    inside the crop (coordinates shifted to the crop's top-left corner).

    Regions that end up with no polygons, or with door/window polygons only,
    are skipped: nothing is written for them and they do not consume an image id.

    Args:
        source_image (numpy.ndarray): The source image (H, W, 3).
        source_annot_path (str): COCO-style JSON file annotating ``source_image``.
        region_polygons (list): Region outlines, each a list of (x, y) points
            given in source-image pixel coordinates.
        start_image_id (int): Image id for the first region that is kept;
            subsequent kept regions count up from it.
        output_image_dir (str): Directory that receives the cropped images.
        output_annot_dir (str): Directory that receives one JSON per cropped image.
        output_aux_dir (str): Directory that receives the debug renderings.
        vis_aux (bool): Whether to render the full floorplan and every kept crop.

    Returns:
        int: The first image id that was not used by this call.
    """
    door_window_index = [10, 9]
    margin = 10  # pixels added around each region's bounding rectangle

    source_polygons, source_misc, categories_dict = read_polygons_from_json(source_annot_path)
    source_img_id = os.path.basename(source_annot_path).split(".")[0].zfill(5)
    img_height, img_width = source_image.shape[:2]

    if vis_aux:
        gt_sem_rich_path = os.path.join(output_aux_dir, "{}_org_floor.png".format(source_img_id))
        plot_floor(
            source_polygons,
            categories_dict,
            img_width,
            img_height,
            gt_sem_rich_path,
            door_window_index=door_window_index,
        )

    # The source polygons do not depend on the region, so convert them only once.
    source_polygons_np = [np.array(src_poly).reshape(-1, 2) for src_poly, _ in source_polygons]

    img_id = start_image_id

    # each region polygon corresponds to an image
    for polygon in region_polygons:
        # reshape keeps a degenerate one-point outline as a valid (1, 2) point set
        points = np.array(polygon, dtype=np.int32).reshape(-1, 2)

        # Bounding rectangle of the region, grown by the margin and clipped to the image
        x, y, w, h = cv2.boundingRect(points)
        x0 = max(x - margin, 0)
        y0 = max(y - margin, 0)
        x1 = min(x + w + margin, img_width)
        y1 = min(y + h + margin, img_height)
        w, h = x1 - x0, y1 - y0
        crop_origin = np.array([x0, y0])
        crop_limit = np.array([x1, y1])

        coco_annotation_dict_list = []
        output_coco_polygons = []
        for j, src_np in enumerate(source_polygons_np):
            # Keep only polygons that lie completely inside the crop
            if np.any(src_np.min(axis=0) < crop_origin) or np.any(src_np.max(axis=0) > crop_limit):
                continue

            # Express the polygon relative to the top-left corner of the crop
            scaled_polygon = src_np - crop_origin

            coco_seg_poly = []
            for p in resort_corners(scaled_polygon):
                # tolist() yields native Python numbers, which json can always serialise
                coco_seg_poly += np.asarray(p).tolist()

            if len(scaled_polygon) == 2:
                # Two-point annotations reuse the source area and a shifted copy of the
                # source bbox. Copy first: the same annotation may be visited again
                # for another region and must not accumulate shifts.
                area = source_misc[j]["area"]
                coco_bb = list(source_misc[j]["bbox"])
                coco_bb[0] -= x0
                coco_bb[1] -= y0
            else:
                poly_shapely = Polygon(scaled_polygon)
                area = poly_shapely.area

                # Slightly wider bounding box
                bb_x, bb_y = poly_shapely.envelope.exterior.xy
                coco_bb = create_coco_bounding_box(bb_x, bb_y, w, h, bound_pad=2)

            category_id = source_polygons[j][1]
            coco_annotation_dict_list.append(
                {
                    "segmentation": [coco_seg_poly],
                    "area": area,
                    "iscrowd": 0,
                    "image_id": img_id,
                    "bbox": coco_bb,
                    "category_id": category_id,
                    # instance ids restart at 0 for every crop
                    "id": len(coco_annotation_dict_list),
                }
            )
            output_coco_polygons.append([coco_seg_poly, category_id])

        # skip if just windows and doors are inside (also true for an empty crop)
        if check_all_window_door_inside(output_coco_polygons, door_window_index):
            continue

        file_stem = f"{str(img_id).zfill(5)}_{source_img_id}"

        # Write the crop only now, so that skipped regions leave no stray image behind
        cv2.imwrite(f"{output_image_dir}/{file_stem}.png", source_image[y0:y1, x0:x1])

        save_dict = prepare_dict(categories_dict)
        save_dict["images"].append(
            {
                "file_name": f"{file_stem}.png",
                "id": img_id,
                "width": w,
                "height": h,
            }
        )
        save_dict["annotations"] += coco_annotation_dict_list

        if vis_aux:
            gt_sem_rich_path = os.path.join(output_aux_dir, "{}_floor.png".format(file_stem))
            plot_floor(
                output_coco_polygons, categories_dict, w, h, gt_sem_rich_path, door_window_index=door_window_index
            )

        # Save annotations to a JSON file
        with open(f"{output_annot_dir}/{file_stem}.json", "w") as f:
            json.dump(save_dict, f)

        img_id += 1

    return img_id
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def config():
    """Parse the command-line options of the region-extraction script.

    Returns:
        argparse.Namespace: Holds ``data_root`` (input dataset folder) and
            ``output`` (destination folder).
    """
    parser = argparse.ArgumentParser(
        description="Extract per-region crops and COCO annotations from a COCO-format floorplan dataset"
    )
    parser.add_argument(
        "--data_root",
        default="Structured3D_panorama",
        type=str,
        help="path to the input dataset folder (expects <split>/*.png, "
        "annotations_json/<split>/*.json and <split>_aux/*_mask.png)",
    )
    parser.add_argument("--output", default="coco_cubicasa5k", type=str, help="path to output folder")

    return parser.parse_args()
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# Number of output image ids reserved for every source image. Workers run in
# parallel, so each source image gets a fixed id window up front.
# NOTE(review): a source image yielding more regions than this would spill into
# the next image's window — confirm that 10 is a safe upper bound for the data.
_IDS_PER_SOURCE_IMAGE = 10


def _process_sample(task):
    """Process one source image: extract its region polygons and write all crops.

    ``task`` is a self-contained tuple so the function can be shipped to a worker
    process without relying on globals inherited from the parent.

    Returns the first image id left unused by this source image.
    """
    image_path, annot_path, mask_path, first_image_id, save_path, save_anno_path, save_aux_path = task

    cur_image_id = int(os.path.basename(image_path).split(".")[0])

    # cv2.imread signals a missing/unreadable file by returning None
    mask = cv2.imread(mask_path)
    if mask is None:
        raise FileNotFoundError(f"Cannot read mask image: {mask_path}")
    source_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if source_image is None:
        raise FileNotFoundError(f"Cannot read source image: {image_path}")

    # Extract polygons from the last channel of the mask image
    region_polygons = extract_polygons_from_mask(
        mask[:, :, -1], output_mask_path=f"{save_aux_path}/{str(cur_image_id).zfill(5)}_polylines.png"
    )

    return extract_region_and_annotation(
        source_image,
        annot_path,
        region_polygons,
        first_image_id,
        save_path,
        save_anno_path,
        save_aux_path,
        vis_aux=True,
    )


if __name__ == "__main__":
    args = config()

    ### prepare
    out_folder = args.output
    annotation_out_folder = os.path.join(out_folder, "annotations_json")
    splits = ["train", "val", "test"]

    save_folders = [os.path.join(out_folder, split) for split in splits]
    annos_folders = [os.path.join(annotation_out_folder, split) for split in splits]
    aux_folders = [folder.rstrip("/") + "_aux" for folder in save_folders]

    # makedirs also creates missing parents and tolerates existing folders
    for folder in [out_folder, annotation_out_folder, *annos_folders, *save_folders, *aux_folders]:
        os.makedirs(folder, exist_ok=True)

    ### begin processing
    start_image_id = 3500
    num_processes = 16

    for split, save_path, save_anno_path, save_aux_path in zip(splits, save_folders, annos_folders, aux_folders):
        image_files = sorted(glob.glob(f"{args.data_root}/{split}/*.png"))

        tasks = []
        for index, image_path in enumerate(image_files):
            id_ = os.path.basename(image_path).split(".")[0]
            tasks.append(
                (
                    image_path,
                    f"{args.data_root}/annotations_json/{split}/{id_}.json",
                    f"{args.data_root}/{split}_aux/{id_}_mask.png",
                    start_image_id + index * _IDS_PER_SOURCE_IMAGE,
                    save_path,
                    save_anno_path,
                    save_aux_path,
                )
            )

        with Pool(num_processes) as p:
            # list(...) drains the lazy imap iterator so that all tasks actually run
            list(tqdm(p.imap(_process_sample, tasks), total=len(tasks)))

        # The next split continues after the id windows reserved for this one
        start_image_id += len(tasks) * _IDS_PER_SOURCE_IMAGE
|
data_preprocess/cubicasa5k/house.py
ADDED
|
@@ -0,0 +1,1131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from xml.dom import minidom
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from skimage.draw import polygon
|
| 7 |
+
from svg_utils import (
|
| 8 |
+
PolygonWall,
|
| 9 |
+
calc_distance,
|
| 10 |
+
get_direction,
|
| 11 |
+
get_gaussian2D,
|
| 12 |
+
get_icon,
|
| 13 |
+
get_icon_number,
|
| 14 |
+
get_points,
|
| 15 |
+
get_room_number,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
# Room label name -> integer id. Ids are consecutive and follow the order of
# the names below, starting at 0.
all_rooms = {
    room_name: room_id
    for room_id, room_name in enumerate(
        (
            "Background",  # Not in data. The default outside label
            "Alcove",
            "Attic",
            "Ballroom",
            "Bar",
            "Basement",
            "Bath",
            "Bedroom",
            "Below150cm",
            "CarPort",
            "Church",
            "Closet",
            "ConferenceRoom",
            "Conservatory",
            "Counter",
            "Den",
            "Dining",
            "DraughtLobby",
            "DressingRoom",
            "EatingArea",
            "Elevated",
            "Elevator",
            "Entry",
            "ExerciseRoom",
            "Garage",
            "Garbage",
            "Hall",
            "HallWay",
            "HotTub",
            "Kitchen",
            "Library",
            "LivingRoom",
            "Loft",
            "Lounge",
            "MediaRoom",
            "MeetingRoom",
            "Museum",
            "Nook",
            "Office",
            "OpenToBelow",
            "Outdoor",
            "Pantry",
            "Reception",
            "RecreationRoom",
            "RetailSpace",
            "Room",
            "Sanctuary",
            "Sauna",
            "ServiceRoom",
            "ServingArea",
            "Skylights",
            "Stable",
            "Stage",
            "StairWell",
            "Storage",
            "SunRoom",
            "SwimmingPool",
            "TechnicalRoom",
            "Theatre",
            "Undefined",
            "UserDefined",
            "Utility",
            "Wall",
            "Railing",
            "Stairs",
        )
    )
}
|
| 85 |
+
|
| 86 |
+
# Room label name -> reduced class id (values 0..11). Several labels share an id;
# 11 is the most common target.
rooms_selected = {
    "Alcove": 11,
    "Attic": 11,
    "Ballroom": 11,
    "Bar": 11,
    "Basement": 11,
    "Bath": 6,
    "Bedroom": 5,
    "CarPort": 10,
    "Church": 11,
    "Closet": 9,
    "ConferenceRoom": 11,
    "Conservatory": 11,
    "Counter": 11,
    "Den": 11,
    "Dining": 4,
    "DraughtLobby": 7,
    "DressingRoom": 9,
    "EatingArea": 4,
    "Elevated": 11,
    "Elevator": 11,
    "Entry": 7,
    "ExerciseRoom": 11,
    "Garage": 10,
    "Garbage": 11,
    "Hall": 11,
    "HallWay": 7,
    "HotTub": 11,
    "Kitchen": 3,
    "Library": 11,
    "LivingRoom": 4,
    "Loft": 11,
    "Lounge": 4,
    "MediaRoom": 11,
    "MeetingRoom": 11,
    "Museum": 11,
    "Nook": 11,
    "Office": 11,
    "OpenToBelow": 11,
    "Outdoor": 1,
    "Pantry": 11,
    "Reception": 11,
    "RecreationRoom": 11,
    "RetailSpace": 11,
    "Room": 11,
    "Sanctuary": 11,
    "Sauna": 6,
    "ServiceRoom": 11,
    "ServingArea": 11,
    "Skylights": 11,
    "Stable": 11,
    "Stage": 11,
    "StairWell": 11,
    "Storage": 9,
    "SunRoom": 11,
    "SwimmingPool": 11,
    "TechnicalRoom": 11,
    "Theatre": 11,
    "Undefined": 11,
    "UserDefined": 11,
    "Utility": 11,
    "Background": 0,  # Not in data. The default outside label
    "Wall": 2,
    "Railing": 8,
}
|
| 151 |
+
|
| 152 |
+
# Room label name -> name of the coarser room class it is merged into.
room_name_map = {
    "Alcove": "Room",
    "Attic": "Room",
    "Ballroom": "Room",
    "Bar": "Room",
    "Basement": "Room",
    "Bath": "Bath",
    "Bedroom": "Bedroom",
    "Below150cm": "Room",
    "CarPort": "Garage",
    "Church": "Room",
    "Closet": "Storage",
    "ConferenceRoom": "Room",
    "Conservatory": "Room",
    "Counter": "Room",
    "Den": "Room",
    "Dining": "Dining",
    "DraughtLobby": "Entry",
    "DressingRoom": "Storage",
    "EatingArea": "Dining",
    "Elevated": "Room",
    "Elevator": "Room",
    "Entry": "Entry",
    "ExerciseRoom": "Room",
    "Garage": "Garage",
    "Garbage": "Room",
    "Hall": "Room",
    "HallWay": "Entry",
    "HotTub": "Room",
    "Kitchen": "Kitchen",
    "Library": "Room",
    "LivingRoom": "LivingRoom",
    "Loft": "Room",
    "Lounge": "LivingRoom",
    "MediaRoom": "Room",
    "MeetingRoom": "Room",
    "Museum": "Room",
    "Nook": "Room",
    "Office": "Room",
    "OpenToBelow": "Room",
    "Outdoor": "Outdoor",
    "Pantry": "Room",
    "Reception": "Room",
    "RecreationRoom": "Room",
    "RetailSpace": "Room",
    "Room": "Room",
    "Sanctuary": "Room",
    "Sauna": "Bath",
    "ServiceRoom": "Room",
    "ServingArea": "Room",
    "Skylights": "Room",
    "Stable": "Room",
    "Stage": "Room",
    "StairWell": "Room",
    "Storage": "Storage",
    "SunRoom": "Room",
    "SwimmingPool": "Room",
    "TechnicalRoom": "Room",
    "Theatre": "Room",
    "Undefined": "Room",
    "UserDefined": "Room",
    "Utility": "Room",
    "Wall": "Wall",
    "Railing": "Railing",
    "Background": "Background",  # Not in data. The default outside label
}
|
| 218 |
+
|
| 219 |
+
# Icon label name -> integer id. Ids are consecutive and follow the order of
# the names below, starting at 0.
all_icons = {
    icon_name: icon_id
    for icon_id, icon_name in enumerate(
        (
            "Empty",
            "Window",
            "Door",
            "BaseCabinet",
            "BaseCabinetRound",
            "BaseCabinetTriangle",
            "Bathtub",
            "BathtubRound",
            "Chimney",
            "Closet",
            "ClosetRound",
            "ClosetTriangle",
            "CoatCloset",
            "CoatRack",
            "CornerSink",
            "CounterTop",
            "DoubleSink",
            "DoubleSinkRight",
            "ElectricalAppliance",
            "Fireplace",
            "FireplaceCorner",
            "FireplaceRound",
            "GasStove",
            "Housing",
            "Jacuzzi",
            "PlaceForFireplace",
            "PlaceForFireplaceCorner",
            "PlaceForFireplaceRound",
            "RoundSink",
            "SaunaBenchHigh",
            "SaunaBenchLow",
            "SaunaBenchMid",
            "Shower",
            "ShowerCab",
            "ShowerScreen",
            "ShowerScreenRoundLeft",
            "ShowerScreenRoundRight",
            "SideSink",
            "Sink",
            "Toilet",
            "Urinal",
            "WallCabinet",
            "WaterTap",
            "WoodStove",
            "Misc",
            "SaunaBench",
            "SaunaStove",
            "WashingMachine",
            "IntegratedStove",
            "Dishwasher",
            "GeneralAppliance",
            "ShowerPlatform",
        )
    )
}
|
| 273 |
+
|
| 274 |
+
# Maps raw icon class names to the merged icon class id used by default in
# House. Several raw names share one id; names mapped to None carry no id.
icons_selected = {
    icon_name: class_id
    for class_id, icon_names in (
        (1, ("Window",)),
        (2, ("Door",)),
        (
            3,
            (
                "Closet",
                "ClosetRound",
                "ClosetTriangle",
                "CoatCloset",
                "CoatRack",
                "CounterTop",
                "Housing",
            ),
        ),
        (4, ("ElectricalAppliance", "WoodStove", "GasStove")),
        (5, ("Toilet", "Urinal")),
        (
            6,
            (
                "SideSink",
                "Sink",
                "RoundSink",
                "CornerSink",
                "DoubleSink",
                "DoubleSinkRight",
                "WaterTap",
            ),
        ),
        (7, ("SaunaBenchHigh", "SaunaBenchLow", "SaunaBenchMid", "SaunaBench")),
        (
            8,
            (
                "Fireplace",
                "FireplaceCorner",
                "FireplaceRound",
                "PlaceForFireplace",
                "PlaceForFireplaceCorner",
                "PlaceForFireplaceRound",
            ),
        ),
        (9, ("Bathtub", "BathtubRound")),
        (10, ("Chimney",)),
        (
            None,
            (
                "Misc",
                "BaseCabinetRound",
                "BaseCabinetTriangle",
                "BaseCabinet",
                "WallCabinet",
                "Shower",
                "ShowerCab",
                "ShowerPlatform",
                "ShowerScreen",
                "ShowerScreenRoundRight",
                "ShowerScreenRoundLeft",
                "Jacuzzi",
            ),
        ),
    )
    for icon_name in icon_names
}
|
| 322 |
+
|
| 323 |
+
# Maps raw icon class names to the merged icon class name stored in
# House.representation["icons"]. A value of None means the icon has no merged
# class. The three shower entries previously held the *string* "None", which
# is truthy and indistinguishable from a real class name; they now use the
# None object like every other unmapped icon.
icon_name_map = {
    "Window": "Window",
    "Door": "Door",
    "Closet": "Closet",
    "ClosetRound": "Closet",
    "ClosetTriangle": "Closet",
    "CoatCloset": "Closet",
    "CoatRack": "Closet",
    "CounterTop": "Closet",
    "Housing": "Closet",
    "ElectricalAppliance": "ElectricalAppliance",
    "WoodStove": "ElectricalAppliance",
    "GasStove": "ElectricalAppliance",
    "SaunaStove": "ElectricalAppliance",
    "Toilet": "Toilet",
    "Urinal": "Toilet",
    "SideSink": "Sink",
    "Sink": "Sink",
    "RoundSink": "Sink",
    "CornerSink": "Sink",
    "DoubleSink": "Sink",
    "DoubleSinkRight": "Sink",
    "WaterTap": "Sink",
    "SaunaBenchHigh": "SaunaBench",
    "SaunaBenchLow": "SaunaBench",
    "SaunaBenchMid": "SaunaBench",
    "SaunaBench": "SaunaBench",
    "Fireplace": "Fireplace",
    "FireplaceCorner": "Fireplace",
    "FireplaceRound": "Fireplace",
    "PlaceForFireplace": "Fireplace",
    "PlaceForFireplaceCorner": "Fireplace",
    "PlaceForFireplaceRound": "Fireplace",
    "Bathtub": "Bathtub",
    "BathtubRound": "Bathtub",
    "Chimney": "Chimney",
    "Misc": None,
    "BaseCabinetRound": None,
    "BaseCabinetTriangle": None,
    "BaseCabinet": None,
    "WallCabinet": None,
    "Shower": None,
    "ShowerCab": None,
    "ShowerPlatform": None,
    "ShowerScreen": None,
    "ShowerScreenRoundRight": None,
    "ShowerScreenRoundLeft": None,
    "Jacuzzi": None,
    "WashingMachine": None,
    "IntegratedStove": "ElectricalAppliance",
    "Dishwasher": "ElectricalAppliance",
    "GeneralAppliance": "ElectricalAppliance",
}
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def complete_polygons(polygons, polygon_types):
    """Close each polygon ring and drop polygons with fewer than 3 points.

    Args:
        polygons: sequence of polygons, each a list of points.
        polygon_types: labels aligned one-to-one with ``polygons``.

    Returns:
        A ``(closed_polygons, kept_types)`` pair. Every kept polygon is a
        shallow copy of the input whose first point is appended again at the
        end unless the ring was already closed; skipped polygons are removed
        from both lists and reported on stdout.
    """
    closed_polygons, kept_types = [], []
    for ring, label in zip(polygons, polygon_types):
        if len(ring) < 3:
            print(f"Class {label} has less than 3 points. Skipped!")
            continue

        vertices = np.array(ring)
        closed_ring = copy.copy(ring)
        # The ring is open when the first and last points differ in any
        # coordinate; repeat the first point to close it.
        is_open = (vertices[0] != vertices[-1]).any()
        if is_open:
            closed_ring.append(ring[0])

        closed_polygons.append(closed_ring)
        kept_types.append(label)

    return closed_polygons, kept_types
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class House:
|
| 397 |
+
def __init__(self, path, height, width, icon_list=icons_selected, room_list=rooms_selected):
    """Parse a floor-plan SVG and build raster maps, polygons and junctions.

    Every ``<g>`` element of the SVG is inspected: elements whose ``id`` is
    "Wall" or "Railing" become ``PolygonWall`` objects, elements whose
    ``id`` is "Window" or "Door" become openings, and elements whose
    ``class`` contains "FixedFurniture " or "Space " become icons and rooms.
    Afterwards walls are merged, junction points are derived and all
    collected polygons are closed.

    Args:
        path: SVG file handed to ``minidom.parse``.
        height: number of rows of the label maps.
        width: number of columns of the label maps.
        icon_list: mapping passed to ``get_icon_number`` to resolve icon ids.
        room_list: mapping passed to ``get_room_number``; must contain the
            keys "Wall" and "Railing".

    Raises:
        ValueError: re-raised from wall construction unless its message is
            exactly "small wall" (those elements are skipped).
    """
    self.height = height
    self.width = width
    shape = height, width
    svg = minidom.parse(path)
    # Per-pixel room/wall class labels.
    self.walls = np.empty((height, width), dtype=np.uint8)
    self.walls.fill(0)
    # Per-pixel wall instance ids.
    # NOTE(review): uint8 storage means ids above 255 cannot be represented
    # exactly - confirm plans never contain that many wall elements.
    self.wall_ids = np.empty((height, width), dtype=np.uint8)
    self.wall_ids.fill(0)
    # Per-pixel icon class labels.
    self.icons = np.zeros((height, width), dtype=np.uint8)
    # junction_id = 0
    wall_id = 1
    self.wall_ends = []
    self.wall_objs = []
    self.icon_types = []
    self.room_types = []
    self.icon_corners = {"upper_left": [], "upper_right": [], "lower_left": [], "lower_right": []}
    self.opening_corners = {"left": [], "right": [], "up": [], "down": []}
    self.representation = {"doors": [], "icons": [], "labels": [], "walls": []}

    self.icon_areas = []
    # Polygon vertex lists; wall_coords is kept aligned with room_types and
    # icon_coords with icon_types (checked in get_coords_and_labels).
    self.wall_coords = []
    self.icon_coords = []

    for e in svg.getElementsByTagName("g"):
        try:
            if e.getAttribute("id") == "Wall":
                wall = PolygonWall(e, wall_id, shape)
                wall.rr, wall.cc = self._clip_outside(wall.rr, wall.cc)
                self.wall_objs.append(wall)
                self.walls[wall.rr, wall.cc] = room_list["Wall"]
                self.wall_ids[wall.rr, wall.cc] = wall_id
                self.wall_ends.append(wall.end_points)

                Y, X = self._clip_outside(wall.Y, wall.X)
                self.wall_coords.append([(x, y) for x, y in zip(X, Y)])
                self.room_types.append(room_list["Wall"])

                wall_id += 1

            if e.getAttribute("id") == "Railing":
                wall = PolygonWall(e, wall_id, shape)
                wall.rr, wall.cc = self._clip_outside(wall.rr, wall.cc)
                self.wall_objs.append(wall)
                self.walls[wall.rr, wall.cc] = room_list["Railing"]
                self.wall_ids[wall.rr, wall.cc] = wall_id
                self.wall_ends.append(wall.end_points)

                Y, X = self._clip_outside(wall.Y, wall.X)
                self.wall_coords.append([(x, y) for x, y in zip(X, Y)])
                self.room_types.append(room_list["Railing"])

                wall_id += 1

        except ValueError as k:
            # Only the "small wall" error is tolerated: the element is
            # skipped entirely. Any other ValueError propagates.
            if str(k) != "small wall":
                raise k
            continue

        if e.getAttribute("id") == "Window":
            X, Y = get_points(e)
            # polygon() is called with (X, Y) here, so `rr` holds x values and
            # `cc` holds y values; the swap on the next line and the
            # `self.icons[cc, rr]` indexing below keep (row, col) order.
            rr, cc = polygon(X, Y)
            cc, rr = self._clip_outside(cc, rr)
            direction = get_direction(X, Y)
            locs = np.column_stack((X, Y))
            if direction == "H":
                # Take the two points with the smallest x as the left edge;
                # the remaining points form the right edge.
                left_index = np.argmin(locs[:, 0])
                left1 = locs[left_index]
                locs = np.delete(locs, left_index, axis=0)
                left_index = np.argmin(locs[:, 0])
                left2 = locs[left_index]
                right = np.delete(locs, left_index, axis=0)
                left = np.array([left1, left2])

                # Opening end points are the midpoints of each edge.
                point_left = left.mean(axis=0)
                point_right = right.mean(axis=0)
                self.opening_corners["left"].append(point_left)
                self.opening_corners["right"].append(point_right)

                # Windows are recorded under the "doors" key with the "door" tag.
                door_rep = [[list(point_left), list(point_right)], ["door", 1, 1]]
                self.representation["doors"].append(door_rep)
            else:
                # Same procedure along y: the two points with the smallest y
                # form the upper edge, the rest the lower edge.
                up_index = np.argmin(locs[:, 1])
                up1 = locs[up_index]
                locs = np.delete(locs, up_index, axis=0)
                up_index = np.argmin(locs[:, 1])
                up2 = locs[up_index]
                down = np.delete(locs, up_index, axis=0)
                up = np.array([up1, up2])

                point_up = up.mean(axis=0)
                point_down = down.mean(axis=0)
                self.opening_corners["up"].append(point_up)
                self.opening_corners["down"].append(point_down)

                door_rep = [[list(point_up), list(point_down)], ["door", 1, 1]]
                self.representation["doors"].append(door_rep)

            # Icon class 1 is written for windows.
            self.icons[cc, rr] = 1
            self.icon_types.append(1)

            Y, X = self._clip_outside(Y, X)
            self.icon_coords.append([(x, y) for x, y in zip(X, Y)])

        if e.getAttribute("id") == "Door":
            # How to represent empty door space
            X, Y = get_points(e)
            # Same (X, Y) argument order and swap as in the window branch.
            rr, cc = polygon(X, Y)
            cc, rr = self._clip_outside(cc, rr)
            direction = get_direction(X, Y)
            locs = np.column_stack((X, Y))
            if direction == "H":
                left_index = np.argmin(locs[:, 0])
                left1 = locs[left_index]
                locs = np.delete(locs, left_index, axis=0)
                left_index = np.argmin(locs[:, 0])
                left2 = locs[left_index]
                right = np.delete(locs, left_index, axis=0)
                left = np.array([left1, left2])

                point_left = left.mean(axis=0)
                point_right = right.mean(axis=0)
                self.opening_corners["left"].append(left.mean(axis=0))
                self.opening_corners["right"].append(right.mean(axis=0))

                door_rep = [[list(point_left), list(point_right)], ["door", 1, 1]]
                self.representation["doors"].append(door_rep)
            else:
                up_index = np.argmin(locs[:, 1])
                up1 = locs[up_index]
                locs = np.delete(locs, up_index, axis=0)
                up_index = np.argmin(locs[:, 1])
                up2 = locs[up_index]
                down = np.delete(locs, up_index, axis=0)
                up = np.array([up1, up2])

                point_up = up.mean(axis=0)
                point_down = down.mean(axis=0)
                self.opening_corners["up"].append(up.mean(axis=0))
                self.opening_corners["down"].append(down.mean(axis=0))

                door_rep = [[list(point_up), list(point_down)], ["door", 1, 1]]
                self.representation["doors"].append(door_rep)

            # Icon class 2 is written for doors.
            self.icons[cc, rr] = 2
            self.icon_types.append(2)

            Y, X = self._clip_outside(Y, X)
            self.icon_coords.append([(x, y) for x, y in zip(X, Y)])

        if "FixedFurniture " in e.getAttribute("class"):
            num = get_icon_number(e, icon_list)
            # Icons without an id in icon_list are ignored completely.
            if num is not None:
                rr, cc, X, Y = get_icon(e)
                # only four corner icons
                if len(X) == 4:
                    locs = np.column_stack((X, Y))
                    # Smallest x + y is taken as the upper-left corner.
                    up_left_index = locs.sum(axis=1).argmin()
                    self.icon_corners["upper_left"].append(locs[up_left_index])
                    up_left = list(locs[up_left_index])
                    locs = np.delete(locs, up_left_index, axis=0)
                    # Largest x + y among the rest is the lower-right corner.
                    down_right_index = locs.sum(axis=1).argmax()
                    self.icon_corners["lower_right"].append(locs[down_right_index])
                    down_right = list(locs[down_right_index])
                    locs = np.delete(locs, down_right_index, axis=0)
                    # Of the two remaining points, the one with the smaller y
                    # is the upper-right corner; the last one is lower-left.
                    up_right_index = locs[:, 1].argmin()
                    self.icon_corners["upper_right"].append(locs[up_right_index])
                    locs = np.delete(locs, up_right_index, axis=0)
                    self.icon_corners["lower_left"].append(locs[0])

                    # The first token after "FixedFurniture " is the raw icon
                    # name, translated through icon_name_map.
                    icon_name = e.getAttribute("class").replace("FixedFurniture ", "").split(" ")[0]
                    icon_name = icon_name_map[icon_name]

                    icon_rep = [[up_left, down_right], [icon_name, 1, 1]]
                    self.representation["icons"].append(icon_rep)

                rr, cc = self._clip_outside(rr, cc)
                self.icon_areas.append(len(rr))
                self.icons[rr, cc] = num
                self.icon_types.append(num)

                Y, X = self._clip_outside(Y, X)
                self.icon_coords.append([(x, y) for x, y in zip(X, Y)])

        if "Space " in e.getAttribute("class"):
            num = get_room_number(e, room_list)
            # rr, cc = get_polygon(e)
            X, Y = get_points(e)
            rr, cc = polygon(Y, X)
            if len(rr) != 0:
                rr, cc = self._clip_outside(rr, cc)
                # Rooms that end up with no pixel after clipping are dropped.
                if len(rr) != 0 and len(cc) != 0:
                    self.walls[rr, cc] = num
                    self.room_types.append(num)

                    Y, X = self._clip_outside(Y, X)
                    self.wall_coords.append([(x, y) for x, y in zip(X, Y)])

                    # The label box is a 20x20 square centred on the mean
                    # pixel position of the room.
                    rr_mean = int(round(np.mean(rr)))
                    cc_mean = int(round(np.mean(cc)))
                    center_box = [[rr_mean - 10, cc_mean - 10], [rr_mean + 10, cc_mean + 10]]
                    room_name = e.getAttribute("class").replace("Space ", "").split(" ")[0]
                    room_name = room_name_map[room_name]
                    self.representation["labels"].append([center_box, [room_name, 1, 1]])

        # if "Stairs" in e.getAttribute("class"):
        #     for c in e.childNodes:
        #         if c.getAttribute("class") in ["Flight", "Winding"]:
        #             num = room_list["Stairs"]
        #             rr, cc = get_polygon(c)
        #             if len(rr) != 0:
        #                 rr, cc = self._clip_outside(rr, cc)
        #                 if len(rr) != 0 and len(cc) != 0:
        #                     self.walls[rr, cc] = num
        #                     self.room_types.append(num)

        #                     rr_mean = int(round(np.mean(rr)))
        #                     cc_mean = int(round(np.mean(cc)))
        #                     center_box = [[rr_mean-10, cc_mean-10], [rr_mean+10, cc_mean+10]]
        #                     room_name = "Stairs"
        #                     # room_name = room_name_map[room_name]
        #                     self.representation['labels'].append([center_box, [room_name, 1, 1]])

    self.avg_wall_width = self.get_avg_wall_width()

    # connect_walls also sets self.pillar_walls, which is read below.
    self.new_walls = self.connect_walls(self.wall_objs)

    for w in self.new_walls:
        w.change_end_points()

    for w in self.pillar_walls:
        self.new_walls.append(w)

    self.points = self.lines_to_points(self.width, self.height, self.new_walls, self.avg_wall_width)
    self.points = self.merge_joints(self.points, self.avg_wall_width)

    # walls to representation
    for w in self.new_walls:
        end_points = w.end_points.round().astype("int").tolist()
        # Walls named "Wall" get type 1, every other wall gets type 2.
        if w.name == "Wall":
            self.representation["walls"].append([end_points, ["wall", 1, 1]])
        else:
            self.representation["walls"].append([end_points, ["wall", 2, 1]])

    # Append the beginning point at the last position to close each polygon.
    print("Complete room coords")
    self.wall_coords, self.room_types = complete_polygons(self.wall_coords, self.room_types)
    print("Complete icon coords")
    self.icon_coords, self.icon_types = complete_polygons(self.icon_coords, self.icon_types)
|
| 646 |
+
|
| 647 |
+
def get_coords_and_labels(self):
    """Return room and icon polygons together with their labels.

    Returns:
        The tuple ``(wall_coords, room_types, icon_coords, icon_types)``.

    Raises:
        AssertionError: if a polygon list and its label list differ in length.
    """
    room_polygons, room_labels = self.wall_coords, self.room_types
    assert len(room_polygons) == len(room_labels)

    icon_polygons, icon_labels = self.icon_coords, self.icon_types
    assert len(icon_polygons) == len(icon_labels)

    return room_polygons, room_labels, icon_polygons, icon_labels
|
| 651 |
+
|
| 652 |
+
def get_tensor(self):
    """Stack the junction heatmaps, the wall map and the icon map.

    Returns:
        An array whose leading axis holds the channels of ``get_heatmaps()``
        followed by one channel for ``self.walls`` and one for ``self.icons``.
    """
    layers = (
        self.get_heatmaps(),
        self.walls[np.newaxis],
        self.icons[np.newaxis],
    )
    return np.concatenate(layers, axis=0)
|
| 659 |
+
|
| 660 |
+
def get_segmentation_tensor(self):
    """Return the wall map and the icon map stacked along a new first axis.

    Returns:
        An array with two channels: ``self.walls`` then ``self.icons``.
    """
    wall_layer = self.walls[np.newaxis]
    icon_layer = self.icons[np.newaxis]
    return np.concatenate([wall_layer, icon_layer], axis=0)
|
| 666 |
+
|
| 667 |
+
def get_heatmap_dict(self):
    """Collect junction, opening and icon-corner locations per channel.

    Channel layout (21 channels, keys 0..20):
        * ``get_number(type) - 1`` for every wall junction in ``self.points``
        * 13..16: opening corners "left", "right", "up", "down"
        * 17..20: icon corners "upper_left", "upper_right", "lower_left",
          "lower_right"

    Junction coordinates are rounded, corner coordinates are truncated with
    ``int``. Locations outside the ``height`` x ``width`` image are skipped.
    Previously only the upper bound was checked, so negative coordinates
    were kept even though they lie outside the image; they are now dropped
    as well.

    Returns:
        dict mapping each channel index to a list of ``(x, y)`` tuples.
    """
    heatmaps = {channel: [] for channel in range(21)}

    for cord, _, p_type in self.points:
        x = int(np.round(cord[0]))
        y = int(np.round(cord[1]))
        channel = self.get_number(p_type)
        if 0 <= y < self.height and 0 <= x < self.width:
            heatmaps[channel - 1].append((x, y))

    # The order of this tuple fixes the channel of each corner group.
    corner_groups = (
        self.opening_corners["left"],
        self.opening_corners["right"],
        self.opening_corners["up"],
        self.opening_corners["down"],
        self.icon_corners["upper_left"],
        self.icon_corners["upper_right"],
        self.icon_corners["lower_left"],
        self.icon_corners["lower_right"],
    )
    for channel, corners in enumerate(corner_groups, start=13):
        for corner in corners:
            y = int(corner[1])
            x = int(corner[0])
            if 0 <= y < self.height and 0 <= x < self.width:
                heatmaps[channel].append((x, y))

    return heatmaps
|
| 732 |
+
|
| 733 |
+
def get_heatmaps(self):
    """Rasterise junction, opening and icon-corner locations into heatmaps.

    Channel layout (21 channels):
        * ``get_number(type) - 1`` for every wall junction in ``self.points``
        * 13..16: opening corners "left", "right", "up", "down"
        * 17..20: icon corners "upper_left", "upper_right", "lower_left",
          "lower_right"

    Each in-image location is set to 1, then every channel is filtered with
    the kernel returned by ``get_gaussian2D(13)``.

    Locations outside the image are skipped. Previously only the upper bound
    was checked, so a negative coordinate was used as a negative NumPy index
    and marked a pixel on the opposite edge of the image; negative
    coordinates are now skipped too.

    Returns:
        Array of shape ``(21, height, width)``.
    """
    heatmaps = np.zeros((21, self.height, self.width))

    for cord, _, p_type in self.points:
        x = int(np.round(cord[0]))
        y = int(np.round(cord[1]))
        channel = self.get_number(p_type)
        if 0 <= y < self.height and 0 <= x < self.width:
            heatmaps[channel - 1, y, x] = 1

    # The order of this tuple fixes the channel of each corner group.
    corner_groups = (
        self.opening_corners["left"],
        self.opening_corners["right"],
        self.opening_corners["up"],
        self.opening_corners["down"],
        self.icon_corners["upper_left"],
        self.icon_corners["upper_right"],
        self.icon_corners["lower_left"],
        self.icon_corners["lower_right"],
    )
    for channel, corners in enumerate(corner_groups, start=13):
        for corner in corners:
            y = int(corner[1])
            x = int(corner[0])
            if 0 <= y < self.height and 0 <= x < self.width:
                heatmaps[channel, y, x] = 1

    kernel = get_gaussian2D(13)
    for i, h in enumerate(heatmaps):
        heatmaps[i] = cv2.filter2D(h, -1, kernel)

    return heatmaps
|
| 798 |
+
|
| 799 |
+
def _clip_outside(self, rr, cc):
|
| 800 |
+
s = np.column_stack((rr, cc))
|
| 801 |
+
s = s[s[:, 0] < self.height]
|
| 802 |
+
s = s[s[:, 1] < self.width]
|
| 803 |
+
|
| 804 |
+
return s[:, 0], s[:, 1]
|
| 805 |
+
|
| 806 |
+
def lines_to_points(self, width, height, walls, lineWidth):
    """Derive typed junction points from pairs of perpendicular walls.

    Every pair of walls with perpendicular axis-aligned directions is tested
    with ``findNearestJunctionPair``. Depending on which end points take part
    in the junction, a point of group 2 (both walls end there), group 3 (one
    wall ends on the other wall) or group 4 (the walls cross) is emitted.
    Wall end points that take part in no junction become group 1 points.

    Args:
        width: not used by this method.
        height: not used by this method.
        walls: wall objects exposing ``end_points`` and ``max_width``.
        lineWidth: overwritten inside the loop before it is read, so the
            passed value has no effect.

    Returns:
        List of ``[point, point, ["point", group, variant]]`` entries, where
        ``point`` is ``[x, y]``.
    """
    lines = [h.end_points for h in walls]

    points = []
    # One [start_used, end_used] flag pair per wall.
    usedLinePointMask = []

    for lineIndex, line in enumerate(lines):
        usedLinePointMask.append([False, False])

    for lineIndex_1, wall_1 in enumerate(walls):
        line_1 = wall_1.end_points

        lineDim_1 = self.get_lineDim(line_1, 1)
        if lineDim_1 <= -1:
            # If wall is diagonal we skip
            continue

        # Mean coordinate of the wall along the axis it does not extend in.
        fixedValue_1 = (line_1[0][1 - lineDim_1] + line_1[1][1 - lineDim_1]) / 2
        for lineIndex_2, wall_2 in enumerate(walls):
            line_2 = wall_2.end_points

            # Visit each unordered pair only once.
            if lineIndex_2 <= lineIndex_1:
                continue

            lineDim_2 = self.get_lineDim(line_2, 1)
            if lineDim_2 + lineDim_1 != 1:
                # if walls have the same direction we skip
                continue

            fixedValue_2 = (line_2[0][1 - lineDim_2] + line_2[1][1 - lineDim_2]) / 2
            # The junction tolerance is the wider of the two walls.
            lineWidth = max(wall_1.max_width, wall_2.max_width)
            nearestPair, minDistance = self.findNearestJunctionPair(line_1, line_2, lineWidth)

            if minDistance <= lineWidth:
                pointIndex_1 = nearestPair[0]
                pointIndex_2 = nearestPair[1]
                if pointIndex_1 > -1 and pointIndex_2 > -1:
                    # Both walls end at the junction: group 2. The junction
                    # is placed where the two wall axes intersect.
                    point = [None, None]
                    point[lineDim_1] = fixedValue_2
                    point[lineDim_2] = fixedValue_1
                    # Signed offset from the junction to the far end of each
                    # wall; the sign pattern selects the variant.
                    side = [None, None]
                    side[lineDim_1] = line_1[1 - pointIndex_1][lineDim_1] - fixedValue_2
                    side[lineDim_2] = line_2[1 - pointIndex_2][lineDim_2] - fixedValue_1

                    if side[0] < 0 and side[1] < 0:
                        points.append([point, point, ["point", 2, 1]])
                    elif side[0] > 0 and side[1] < 0:
                        points.append([point, point, ["point", 2, 2]])
                    elif side[0] > 0 and side[1] > 0:
                        points.append([point, point, ["point", 2, 3]])
                    elif side[0] < 0 and side[1] > 0:
                        points.append([point, point, ["point", 2, 4]])

                    usedLinePointMask[lineIndex_1][pointIndex_1] = True
                    usedLinePointMask[lineIndex_2][pointIndex_2] = True
                elif (pointIndex_1 > -1 and pointIndex_2 == -1) or (pointIndex_1 == -1 and pointIndex_2 > -1):
                    # Exactly one wall ends on the other wall: group 3.
                    if pointIndex_1 > -1:
                        lineDim = lineDim_1
                        pointIndex = pointIndex_1
                        fixedValue = fixedValue_2
                        pointValue = line_1[pointIndex_1][1 - lineDim_1]
                        usedLinePointMask[lineIndex_1][pointIndex_1] = True
                    else:
                        lineDim = lineDim_2
                        pointIndex = pointIndex_2
                        fixedValue = fixedValue_1
                        pointValue = line_2[pointIndex_2][1 - lineDim_2]
                        usedLinePointMask[lineIndex_2][pointIndex_2] = True

                    point = [None, None]
                    point[lineDim] = fixedValue
                    point[1 - lineDim] = pointValue

                    # The variant depends on which end of the wall touches
                    # and on the direction of that wall.
                    if pointIndex == 0:
                        if lineDim == 0:
                            points.append([point, point, ["point", 3, 4]])
                        else:
                            points.append([point, point, ["point", 3, 1]])
                    else:
                        if lineDim == 0:
                            points.append([point, point, ["point", 3, 2]])
                        else:
                            points.append([point, point, ["point", 3, 3]])

            elif (
                line_1[0][lineDim_1] < fixedValue_2
                and line_1[1][lineDim_1] > fixedValue_2
                and line_2[0][lineDim_2] < fixedValue_1
                and line_2[1][lineDim_2] > fixedValue_1
            ):
                # No end point is close, but each wall spans the other's
                # axis: the walls cross, which is the single group 4 type.
                point = [None, None]
                point[lineDim_1] = fixedValue_2
                point[lineDim_2] = fixedValue_1
                points.append([point, point, ["point", 4, 1]])

    # End points that took part in no junction become group 1 points.
    # Diagonal walls (lineDim == -1) emit nothing here.
    for lineIndex, pointMask in enumerate(usedLinePointMask):
        lineDim = self.get_lineDim(lines[lineIndex], 1)
        for pointIndex in range(2):
            if pointMask[pointIndex] is True:
                continue
            point = [lines[lineIndex][pointIndex][0], lines[lineIndex][pointIndex][1]]
            if pointIndex == 0:
                if lineDim == 0:
                    points.append([point, point, ["point", 1, 4]])
                elif lineDim == 1:
                    points.append([point, point, ["point", 1, 1]])
            else:
                if lineDim == 0:
                    points.append([point, point, ["point", 1, 2]])
                elif lineDim == 1:
                    points.append([point, point, ["point", 1, 3]])

    return points
|
| 919 |
+
|
| 920 |
+
def _pointId2index(self, g, t):
|
| 921 |
+
g_ = g - 1
|
| 922 |
+
t_ = t - 1
|
| 923 |
+
k = g_ * 4 + t_
|
| 924 |
+
return k
|
| 925 |
+
|
| 926 |
+
def _index2pointId(self, k):
|
| 927 |
+
g = k // 4 + 1
|
| 928 |
+
t = k % 4 + 1
|
| 929 |
+
return [g, t]
|
| 930 |
+
|
| 931 |
+
def _are_close(self, p1, p2, width):
    """Tell whether ``calc_distance(p1, p2)`` is strictly below ``width``."""
    separation = calc_distance(p1, p2)
    return separation < width
|
| 933 |
+
|
| 934 |
+
def merge_joints(self, points, wall_width):
    """Merge junction points that lie closer together than ``wall_width``.

    Points are grouped greedily: each not-yet-merged point collects every
    other not-yet-merged point that ``_are_close`` to it. A group of several
    points is reduced to one point that keeps the coordinates of the group's
    first point and whose type is obtained by folding the members' types
    through the lookup table below.

    Args:
        points: entries of the form ``[coords, coords, ["point", g, t]]``.
        wall_width: distance threshold handed to ``_are_close``.

    Returns:
        New list of points with close points merged.
    """
    # lookuptable[a][b] is the type index that results from merging a point
    # of type index ``a`` with one of type index ``b`` (indices as produced by
    # _pointId2index, 0..12). None means the two types are not combined and
    # the current type is kept.
    lookuptable = {}
    lookuptable[0] = {0: 0, 1: 7, 2: None, 3: 6, 4: 9, 5: 11, 6: 6, 7: 7, 8: 8, 9: 9, 10: 12, 11: 11, 12: 12}
    lookuptable[1] = {0: 7, 1: 1, 2: 4, 3: None, 4: 4, 5: 10, 6: 8, 7: 7, 8: 8, 9: 9, 10: 10, 11: 12, 12: 12}
    lookuptable[2] = {0: None, 1: 4, 2: 2, 3: 5, 4: 4, 5: 5, 6: 11, 7: 9, 8: 12, 9: 9, 10: 10, 11: 11, 12: 12}
    lookuptable[3] = {0: 6, 1: None, 2: 5, 3: 3, 4: 10, 5: 5, 6: 6, 7: 8, 8: 8, 9: 12, 10: 10, 11: 11, 12: 12}
    lookuptable[4] = {0: 9, 1: 4, 2: 4, 3: 10, 4: 4, 5: 10, 6: 12, 7: 9, 8: 12, 9: 9, 10: 10, 11: 12, 12: 12}
    lookuptable[5] = {0: 11, 1: 10, 2: 5, 3: 5, 4: 10, 5: 5, 6: 11, 7: 12, 8: 12, 9: 12, 10: 10, 11: 11, 12: 12}
    lookuptable[6] = {0: 6, 1: 8, 2: 11, 3: 6, 4: 12, 5: 11, 6: 6, 7: 8, 8: 8, 9: 12, 10: 12, 11: 11, 12: 12}
    lookuptable[7] = {0: 7, 1: 7, 2: 9, 3: 8, 4: 9, 5: 12, 6: 8, 7: 7, 8: 8, 9: 9, 10: 12, 11: 12, 12: 12}
    lookuptable[8] = {0: 8, 1: 8, 2: 12, 3: 8, 4: 12, 5: 12, 6: 8, 7: 8, 8: 8, 9: 12, 10: 12, 11: 12, 12: 12}
    lookuptable[9] = {0: 9, 1: 9, 2: 9, 3: 12, 4: 9, 5: 12, 6: 12, 7: 9, 8: 12, 9: 9, 10: 12, 11: 12, 12: 12}
    lookuptable[10] = {
        0: 12,
        1: 10,
        2: 10,
        3: 10,
        4: 10,
        5: 10,
        6: 12,
        7: 12,
        8: 12,
        9: 12,
        10: 10,
        11: 12,
        12: 12,
    }
    lookuptable[11] = {
        0: 11,
        1: 12,
        2: 11,
        3: 11,
        4: 12,
        5: 11,
        6: 11,
        7: 12,
        8: 12,
        9: 12,
        10: 12,
        11: 11,
        12: 12,
    }
    lookuptable[12] = {
        0: 12,
        1: 12,
        2: 12,
        3: 12,
        4: 12,
        5: 12,
        6: 12,
        7: 12,
        8: 12,
        9: 12,
        10: 12,
        11: 12,
        12: 12,
    }

    newPoints = []
    # merged[j] marks points already absorbed into an earlier group.
    merged = [False] * len(points)
    for i, point1 in enumerate(points):
        if merged[i] is False:
            # Gather every still-unmerged point close to point1.
            pool = [point1]
            for j, point2 in enumerate(points):
                if j != i and merged[j] is False and self._are_close(point1[0], point2[0], wall_width):
                    merged[j] = True
                    pool.append(point2)

            if len(pool) == 1:
                newPoints.append(point1)
                merged[i] = True
            else:
                # Fold the pool into one point, updating the type pairwise.
                p_ = pool[0]
                for point_id in range(1, len(pool)):
                    merge_to_p = pool[point_id]

                    k_ = self._pointId2index(p_[2][1], p_[2][2])
                    k_merge_to_p = self._pointId2index(merge_to_p[2][1], merge_to_p[2][2])

                    knew = lookuptable[k_][k_merge_to_p]
                    if knew is None:
                        # Types that cannot be combined: keep the current type.
                        continue

                    typenew = self._index2pointId(knew)
                    p_ = [p_[0], p_[1], ["point", typenew[0], typenew[1]]]

                newPoints.append(p_)

    return newPoints
|
| 1023 |
+
|
| 1024 |
+
def get_avg_wall_width(self):
    """Return the mean ``max_width`` over ``self.wall_objs``.

    The previous implementation divided the running sum by the last
    ``enumerate`` index instead of the number of walls: the result was too
    large by a factor of n / (n - 1), a single wall raised
    ZeroDivisionError, and an empty wall list raised UnboundLocalError.

    Returns:
        The arithmetic mean as a float, or ``0.0`` when there are no walls.
    """
    wall_count = len(self.wall_objs)
    if wall_count == 0:
        return 0.0

    total_width = sum(w.max_width for w in self.wall_objs)
    return total_width / float(wall_count)
|
| 1031 |
+
|
| 1032 |
+
def connect_walls(self, walls):
    """Merge wall pieces into longer walls and split off isolated pillars.

    Walls are addressed by the ids 1..len(walls) through ``find_wall_by_id``.
    A wall that ``wall_is_pillar`` and for which ``merge_possible`` is false
    with every wall is treated as an isolated pillar: it is excluded from
    merging and split with ``split_pillar_wall``; the pieces are stored in
    ``self.pillar_walls``. All other walls are merged greedily with
    ``merge_walls``.

    Args:
        walls: wall objects to process.

    Returns:
        List of merged walls (the pillar pieces are not included).
    """
    new_walls = []
    num_walls = len(walls)
    remaining_walls = list(range(1, num_walls + 1))

    # getting pillars
    remaining_pillar_ids = []
    for p_id in range(1, num_walls + 1):
        p_wall = self.find_wall_by_id(p_id, walls)
        if p_wall.wall_is_pillar(self.avg_wall_width):
            for wall_id in range(1, num_walls + 1):
                wall = self.find_wall_by_id(wall_id, walls)
                if p_wall.merge_possible(wall):
                    break
            else:
                # for/else: reached only when no wall can be merged with
                # this pillar, i.e. the loop above never hit `break`.
                remaining_walls.pop(remaining_walls.index(p_wall.id))
                remaining_pillar_ids.append(p_wall.id)

    while len(remaining_walls) > 0:
        new_wall_id = remaining_walls.pop(0)
        new_wall = self.find_wall_by_id(new_wall_id, walls)

        # Keep sweeping the remaining walls until a full pass merges nothing.
        found = True
        while found:
            found = False
            # NOTE: remaining_walls is modified while it is iterated, so a
            # pass can skip entries; the enclosing `while found` loop
            # re-scans after every pass that merged something.
            for merge_wall_id in remaining_walls:
                merged = self.find_wall_by_id(merge_wall_id, walls)
                temp_wall = new_wall.merge_walls(merged)

                if temp_wall is not None:
                    remaining_walls.pop(remaining_walls.index(merged.id))
                    new_wall = temp_wall
                    found = True

        new_walls.append(new_wall)

    # connect pillars to walls
    new_wall_id = num_walls + 1
    self.pillar_walls = []
    for id in remaining_pillar_ids:
        w = self.find_wall_by_id(id, walls)
        pws = w.split_pillar_wall(new_wall_id, self.avg_wall_width)
        # Four ids are reserved per pillar.
        # NOTE(review): presumably split_pillar_wall returns four walls - confirm.
        new_wall_id += 4
        for pw in pws:
            self.pillar_walls.append(pw)

    return new_walls
|
| 1079 |
+
|
| 1080 |
+
def get_number(self, x):
    """Flatten the pair ``(x[1], x[2])`` into a single number.

    Computes ``(x[1] - 1) * 4 + x[2]``; ``x[0]`` is not used.
    NOTE(review): presumably ``x[1]`` is a 1-based group and ``x[2]`` an
    offset within a group of four -- confirm against callers.
    """
    group = x[1]
    offset = x[2]
    return (group - 1) * 4 + offset
|
| 1082 |
+
|
| 1083 |
+
def get_lineDim(self, line, lineWidth):
    """Classify a segment as horizontal (0), vertical (1) or neither (-1).

    Args:
        line: Two points, each indexable as ``[x, y]``.
        lineWidth: Largest deviation allowed along the minor axis; any
            falsy value (``None``, ``0``) is replaced by 1.

    Returns:
        0 when the x extent dominates and the y extent is within
        ``lineWidth``, 1 for the mirrored case, otherwise -1.
    """
    tolerance = lineWidth or 1
    extent_x = abs(line[0][0] - line[1][0])
    extent_y = abs(line[0][1] - line[1][1])

    if extent_x > extent_y and extent_y <= tolerance:
        return 0
    if extent_y > extent_x and extent_x <= tolerance:
        return 1
    return -1
|
| 1091 |
+
|
| 1092 |
+
def findNearestJunctionPair(self, line_1, line_2, gap):
    """Find the closest pair of endpoints between two segments.

    First the four endpoint-to-endpoint distances are compared (via
    ``calc_distance``). If the smallest one exceeds ``gap`` and one segment
    is horizontal while the other is vertical (per ``get_lineDim`` with a
    width of 1), the distance from an endpoint to the other segment's fixed
    coordinate is also considered when that other segment strictly spans
    the first segment's fixed coordinate.

    Returns:
        ``(pair, distance)`` where ``pair`` is ``[index_1, index_2]``.
        An entry of -1 is used for the segment whose fixed coordinate,
        rather than one of its endpoints, produced the distance.
    """
    closest = None
    for end_1 in range(0, 2):
        for end_2 in range(0, 2):
            span = calc_distance(line_1[end_1], line_2[end_2])
            if closest is None or span < closest:
                pair = [end_1, end_2]
                closest = span

    # Endpoints are already within the allowed gap: nothing more to try.
    if not closest > gap:
        return pair, closest

    dim_1 = self.get_lineDim(line_1, 1)
    dim_2 = self.get_lineDim(line_2, 1)

    # Only a horizontal/vertical combination (0 and 1) sums to exactly 1.
    if dim_1 + dim_2 != 1:
        return pair, closest

    # Mean of each segment's coordinate along its minor axis.
    fixed_1 = (line_1[0][1 - dim_1] + line_1[1][1 - dim_1]) / 2
    fixed_2 = (line_2[0][1 - dim_2] + line_2[1][1 - dim_2]) / 2

    if line_2[0][dim_2] < fixed_1 and line_2[1][dim_2] > fixed_1:
        for end in range(2):
            span = abs(line_1[end][dim_1] - fixed_2)
            if span < closest:
                pair = [end, -1]
                closest = span

    if line_1[0][dim_1] < fixed_2 and line_1[1][dim_1] > fixed_2:
        for end in range(2):
            span = abs(line_2[end][dim_2] - fixed_1)
            if span < closest:
                pair = [-1, end]
                closest = span

    return pair, closest
|
| 1125 |
+
|
| 1126 |
+
def find_wall_by_id(self, id, walls):
    """Return the first wall in ``walls`` whose ``id`` equals ``id``.

    Returns ``None`` when no wall matches.
    """
    matches = (candidate for candidate in walls if candidate.id == id)
    return next(matches, None)
|
data_preprocess/cubicasa5k/loaders.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import lmdb
|
| 2 |
+
import pickle
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from house import House
|
| 8 |
+
from numpy import genfromtxt
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
|
| 11 |
+
# Room class index -> display name; the index is the position in the tuple.
ROOM_NAMES = dict(
    enumerate(
        (
            "Background",
            "Outdoor",
            "Wall",
            "Kitchen",
            "Living Room",
            "Bed Room",
            "Bath",
            "Entry",
            "Railing",
            "Storage",
            "Garage",
            "Undefined",
        )
    )
)
|
| 25 |
+
|
| 26 |
+
# Icon class index -> display name; the index is the position in the tuple.
# The spelling "Electrical Applience" is a runtime string and is kept
# unchanged on purpose.
ICON_NAMES = dict(
    enumerate(
        (
            "No Icon",
            "Window",
            "Door",
            "Closet",
            "Electrical Applience",
            "Toilet",
            "Sink",
            "Sauna Bench",
            "Fire Place",
            "Bathtub",
            "Chimney",
        )
    )
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class FloorplanSVG(Dataset):
    """Dataset pairing floor-plan images with labels parsed from their SVG.

    Each entry of the split file names a sample folder that contains
    ``F1_scaled.png``, ``F1_original.png`` and ``model.svg``.
    """

    def __init__(
        self,
        data_folder,
        data_file,
        is_transform=True,
        augmentations=None,
        img_norm=True,
        format="txt",
        original_size=False,
        lmdb_folder="cubi_lmdb/",
    ):
        """Read the list of sample folders and select the loading backend.

        Args:
            data_folder: Prefix that is string-concatenated with
                ``data_file`` and with every sample folder name.
            data_file: Text file listing one sample folder per line.
            is_transform: When true, ``transform`` is applied to each sample.
            augmentations: Optional callable applied to the sample dict
                before ``transform``.
            img_norm: Stored on the instance; not read by this class.
            format: Only ``"txt"` is wired up. Any other value leaves
                ``get_data`` unset and indexing raises ``RuntimeError``.
            original_size: When true, samples are returned at the
                resolution of ``F1_original.png``.
            lmdb_folder: Unused while the LMDB backend is disabled.
        """
        self.img_norm = img_norm
        self.is_transform = is_transform
        self.augmentations = augmentations
        self.get_data = None
        self.original_size = original_size
        self.image_file_name = "/F1_scaled.png"
        self.org_image_file_name = "/F1_original.png"
        self.svg_file_name = "/model.svg"

        if format == "txt":
            self.get_data = self.get_txt
        # The LMDB backend (format == "lmdb") is disabled: it needs the
        # optional ``lmdb`` package and an opened environment on ``self.lmdb``.

        self.data_folder = data_folder
        # Load txt file to list. genfromtxt returns a 0-d array for a
        # one-line file, which has no len(); force at least one dimension.
        self.folders = np.atleast_1d(genfromtxt(data_folder + data_file, dtype="str"))

    def __len__(self):
        """Return the number of sample folders."""
        return len(self.folders)

    def __getitem__(self, index):
        """Load sample ``index``, then apply augmentations and transform."""
        if self.get_data is None:
            raise RuntimeError(
                "No data backend selected; FloorplanSVG only supports format='txt'."
            )
        sample = self.get_data(index)

        if self.augmentations is not None:
            sample = self.augmentations(sample)

        if self.is_transform:
            sample = self.transform(sample)

        return sample

    @staticmethod
    def _read_rgb_chw(path):
        """Read ``path`` as RGB; return ``(channels-first array, height, width)``."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals a missing or unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # correct color channels
        height, width, _ = image.shape
        return np.moveaxis(image, -1, 0), height, width

    @staticmethod
    def _scale_points(points, coef_width, coef_height):
        """Scale ``(x, y)`` pairs and round them to integer pixel positions."""
        return [(int(round(x * coef_width)), int(round(y * coef_height))) for x, y in points]

    def get_txt(self, index):
        """Build the sample dict for ``index`` from the image and SVG files.

        Returns:
            Dict with keys ``image``, ``label``, ``folder``, ``heatmaps``,
            ``scale``, ``room_polygon``, ``room_type``, ``icon_polygon`` and
            ``icon_type``. ``scale`` is the horizontal scale factor only
            (1 unless ``original_size`` is set).
        """
        sample_dir = self.data_folder + self.folders[index]
        fplan, height, width = self._read_rgb_chw(sample_dir + self.image_file_name)

        # Getting labels for segmentation and heatmaps
        house = House(sample_dir + self.svg_file_name, height, width)
        # Combining them to one numpy tensor
        label = torch.tensor(house.get_segmentation_tensor().astype(np.float32))
        heatmaps = house.get_heatmap_dict()
        room_polygons, room_types, icon_polygons, icon_types = house.get_coords_and_labels()

        coef_width = 1
        if self.original_size:
            fplan, height_org, width_org = self._read_rgb_chw(sample_dir + self.org_image_file_name)

            # interpolate needs a leading batch dimension.
            label = label.unsqueeze(0)
            label = torch.nn.functional.interpolate(label, size=(height_org, width_org), mode="nearest")
            label = label.squeeze(0)

            coef_height = float(height_org) / float(height)
            coef_width = float(width_org) / float(width)

            # Existing keys are reassigned in place, so the dict size is unchanged.
            for key, value in heatmaps.items():
                heatmaps[key] = self._scale_points(value, coef_width, coef_height)

            room_polygons = [self._scale_points(poly, coef_width, coef_height) for poly in room_polygons]
            icon_polygons = [self._scale_points(poly, coef_width, coef_height) for poly in icon_polygons]

        img = torch.tensor(fplan.astype(np.float32))

        return {
            "image": img,
            "label": label,
            "folder": self.folders[index],
            "heatmaps": heatmaps,
            "scale": coef_width,
            "room_polygon": room_polygons,
            "room_type": room_types,
            "icon_polygon": icon_polygons,
            "icon_type": icon_types,
        }

    def get_lmdb(self, index):
        """Load a pickled sample from the LMDB environment on ``self.lmdb``.

        Raises:
            RuntimeError: If no LMDB environment has been attached.
        """
        env = getattr(self, "lmdb", None)
        if env is None:
            raise RuntimeError("LMDB backend is not initialised: self.lmdb is not set.")

        key = self.folders[index].encode()
        with env.begin(write=False) as f:
            data = f.get(key)

        # SECURITY: pickle.loads can execute arbitrary code; only use this
        # with LMDB files from a trusted source.
        sample = pickle.loads(data)
        return sample

    def transform(self, sample):
        """Rescale ``sample["image"]`` from [0, 255] to [-1, 1] in the dict."""
        fplan = sample["image"]
        # Normalization values to range -1 and 1
        fplan = 2 * (fplan / 255.0) - 1

        sample["image"] = fplan

        return sample
|
data_preprocess/cubicasa5k/plotting.py
ADDED
|
@@ -0,0 +1,820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.path as mplp
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import numpy as np
|
| 4 |
+
from matplotlib import cm, colors
|
| 5 |
+
from shapely.geometry import Point, Polygon
|
| 6 |
+
from skimage import draw
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def discrete_cmap_furukawa():
    """Register the ``rooms_furukawa`` and ``icons_furukawa`` colormaps.

    Both are ``ListedColormap`` instances built from fixed colour lists
    (13 room colours, 17 icon colours). Nothing is returned; the colormaps
    become available by name through matplotlib's registry.

    ``cm.register_cmap`` was deprecated in matplotlib 3.7 and removed in
    3.9, so the ``matplotlib.colormaps`` registry is used when present and
    the old function only as a fallback. ``force=True`` lets the function
    be called more than once.
    """
    # Local import: the module only binds submodule aliases, not ``matplotlib``.
    import matplotlib

    # define individual colors as hex values
    room_colors = [
        "#696969",
        "#b3de69",
        "#ffffb3",
        "#8dd3c7",
        "#fdb462",
        "#fccde5",
        "#80b1d3",
        "#d9d9d9",
        "#fb8072",
        "#577a4d",
        "white",
        "#000000",
        "#e31a1c",
    ]
    icon_colors = [
        "#ede676",
        "#8dd3c7",
        "#b15928",
        "#fdb462",
        "#ffff99",
        "#fccde5",
        "#80b1d3",
        "#d9d9d9",
        "#fb8072",
        "#696969",
        "#577a4d",
        "#e31a1c",
        "#42ef59",
        "#8c595a",
        "#3131e5",
        "#48e0e6",
        "white",
    ]

    registry = getattr(matplotlib, "colormaps", None)
    for name, palette in (("rooms_furukawa", room_colors), ("icons_furukawa", icon_colors)):
        cmap = colors.ListedColormap(palette, name)
        if registry is not None:
            registry.register(cmap, force=True)
        else:
            # Older matplotlib without the ColormapRegistry API.
            cm.register_cmap(cmap=cmap)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def drawJunction(h, point, point_type, width, height):
    """Draw the glyph for one junction on ``h``.

    Args:
        h: Object with matplotlib-style ``plot``, ``scatter`` and ``text``
            methods.
        point: ``(x, y)`` position of the junction.
        point_type: Integer junction code. -1 draws a dot, 0-12 draw wall
            junction arms, 13-16 draw opening markers labelled
            OL/OR/OU/OD, 17-20 draw thinner corner arms. Other values draw
            nothing.
        width: Arms pointing towards +x are clamped to ``width - 1``.
        height: Arms pointing towards +y are clamped to ``height - 1``.
    """
    x, y = point

    def arm(direction, length, thickness, color):
        # One line from the junction, clamped to the [0, size - 1] range.
        if direction == "down":
            h.plot([x, x], [y, min(y + length, height - 1)], linewidth=thickness, color=color)
        elif direction == "up":
            h.plot([x, x], [y, max(y - length, 0)], linewidth=thickness, color=color)
        elif direction == "left":
            h.plot([x, max(x - length, 0)], [y, y], linewidth=thickness, color=color)
        else:
            h.plot([x, min(x + length, width - 1)], [y, y], linewidth=thickness, color=color)

    if point_type == -1:
        h.scatter(x, y, color="#6488ea")

    # (code, colour, arms in drawing order) for wall junctions.
    wall_glyphs = (
        (0, "#6488ea", ("down",)),
        (1, "#6241c7", ("left",)),
        (2, "#056eee", ("up",)),
        (3, "#004577", ("right",)),
        (6, "#04d8b2", ("right", "down")),
        (7, "#cdfd02", ("left", "down")),
        (4, "#ff81c0", ("left", "up")),
        (5, "#f97306", ("right", "up")),
        (11, "b", ("right", "up", "down")),
        (8, "y", ("right", "left", "down")),
        (9, "r", ("left", "up", "down")),
        (10, "m", ("right", "left", "up")),
        (12, "k", ("right", "left", "up", "down")),
    )
    for code, color, directions in wall_glyphs:
        if point_type == code:
            for direction in directions:
                arm(direction, 15, 10, color)
            break

    # (code, label, label colour) for opening end points.
    opening_glyphs = (
        (13, "OL", "magenta"),
        (14, "OR", "magenta"),
        (15, "OU", "mediumblue"),
        (16, "OD", "mediumblue"),
    )
    for code, label, label_color in opening_glyphs:
        if point_type == code:
            # White disc on a slightly larger red disc gives a red ring.
            h.plot([x], [y], "o", markersize=30, color="red")
            h.plot([x], [y], "o", markersize=25, color="white")
            h.text(x, y, label, fontsize=30, color=label_color)
            return

    # (code, colour, arms in drawing order) for the thinner corner glyphs.
    corner_glyphs = (
        (17, "indianred", ("right", "down")),
        (18, "darkred", ("left", "down")),
        (19, "salmon", ("right", "up")),
        (20, "orangered", ("left", "up")),
    )
    for code, color, directions in corner_glyphs:
        if point_type == code:
            for direction in directions:
                arm(direction, 10, 5, color)
            return
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def draw_junction_from_dict(point_dict, width, height, size=1, fontsize=30):
    """Draw junction glyphs on the current pyplot axes.

    Args:
        point_dict: Mapping from junction code to an iterable of ``(x, y)``
            locations. Codes 0-12 draw wall junction arms, 13-16 draw
            opening markers labelled OL/OR/OU/OD, 17-20 draw thinner corner
            arms. Other codes draw nothing.
        width: Arms pointing towards +x are clamped to ``width - 1``.
        height: Arms pointing towards +y are clamped to ``height - 1``.
        size: Multiplier for arm length, line width and marker size.
        fontsize: Font size of the opening labels.
    """
    markersize_large = 20 * size
    markersize_small = 15 * size

    # (code, colour, arms in drawing order) for wall junctions.
    wall_glyphs = (
        (0, "#6488ea", ("down",)),
        (1, "#6241c7", ("left",)),
        (2, "#056eee", ("up",)),
        (3, "#004577", ("right",)),
        (6, "#04d8b2", ("right", "down")),
        (7, "#cdfd02", ("left", "down")),
        (4, "#ff81c0", ("left", "up")),
        (5, "#f97306", ("right", "up")),
        (11, "b", ("right", "up", "down")),
        (8, "y", ("right", "left", "down")),
        (9, "r", ("left", "up", "down")),
        (10, "m", ("right", "left", "up")),
        (12, "k", ("right", "left", "up", "down")),
    )
    # (code, label, label colour) for opening end points.
    opening_glyphs = (
        (13, "OL", "magenta"),
        (14, "OR", "magenta"),
        (15, "OU", "mediumblue"),
        (16, "OD", "mediumblue"),
    )
    # (code, colour, arms in drawing order) for the thinner corner glyphs.
    corner_glyphs = (
        (17, "indianred", ("right", "down")),
        (18, "darkred", ("left", "down")),
        (19, "salmon", ("right", "up")),
        (20, "orangered", ("left", "up")),
    )

    def arm(x, y, direction, length, thickness, color):
        # One line from the junction, clamped to the [0, size - 1] range.
        if direction == "down":
            plt.plot([x, x], [y, min(y + length, height - 1)], linewidth=thickness, color=color)
        elif direction == "up":
            plt.plot([x, x], [y, max(y - length, 0)], linewidth=thickness, color=color)
        elif direction == "left":
            plt.plot([x, max(x - length, 0)], [y, y], linewidth=thickness, color=color)
        else:
            plt.plot([x, min(x + length, width - 1)], [y, y], linewidth=thickness, color=color)

    def draw_one(x, y, point_type):
        for code, color, directions in wall_glyphs:
            if point_type == code:
                for direction in directions:
                    arm(x, y, direction, 20 * size, 20 * size, color)
                break

        for code, label, label_color in opening_glyphs:
            if point_type == code:
                # White disc on a larger red disc gives a red ring.
                plt.plot([x], [y], "o", markersize=markersize_large, color="red")
                plt.plot([x], [y], "o", markersize=markersize_small, color="white")
                plt.text(x, y, label, fontsize=fontsize, color=label_color)
                return

        for code, color, directions in corner_glyphs:
            if point_type == code:
                for direction in directions:
                    arm(x, y, direction, 15 * size, 15 * size, color)
                return

    for point_type, locations in point_dict.items():
        for loc in locations:
            x, y = loc
            draw_one(x, y, point_type)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def plot_pre_rec_4(instances, classes):
    """Draw precision-recall curves for four class groups on a 2x2 subplot grid.

    Parameters
    ----------
    instances : array indexed as ``instances[:, class_index, k]`` where, per the
        indexing below, ``k == 0`` holds correct detections, ``k == 1`` false
        negatives and ``k == 2`` false positives.
        NOTE(review): axis 0 presumably enumerates score thresholds - confirm.
    classes : list of class names; every name in the groups below must be present,
        otherwise ``list.index`` raises ``ValueError``.
    """
    # (subplot title, class names summed together for that subplot)
    groups = (
        ("Walls", ["Wall", "Railing"]),
        ("Openings", ["Window", "Door"]),
        (
            "Rooms",
            [
                "Outdoor",
                "Kitchen",
                "Living Room",
                "Bed Room",
                "Entry",
                "Dining",
                "Storage",
                "Garage",
                "Undefined Room",
                "Sauna",
                "Fire Place",
                "Bathtub",
                "Chimney",
            ],
        ),
        (
            "Icons",
            [
                "Bath",
                "Closet",
                "Electrical Appliance",
                "Toilet",
                "Shower",
                "Sink",
                "Sauna",
                "Fire Place",
                "Bathtub",
                "Chimney",
            ],
        ),
    )

    def make_sub_plot(classes_to_plot):
        """Plot one precision-recall step curve on the current axes."""
        plt.ylim([0.0, 1.0])
        plt.xlim([0.0, 1.0])
        plt.xlabel("Recall")
        plt.ylabel("Precision")

        # Pool the counts of every class belonging to this group.
        columns = [classes.index(class_name) for class_name in classes_to_plot]
        pooled = instances[:, columns].sum(axis=1)

        true_pos = pooled[:, 0]
        false_neg = pooled[:, 1]
        false_pos = pooled[:, 2]
        precision = true_pos / (true_pos + false_pos)
        recall = true_pos / (true_pos + false_neg)

        # Recall is reversed while precision is not, exactly as in the original.
        reversed_recall = recall[::-1]
        plt.step(reversed_recall, precision, color="b", alpha=0.2, where="post")
        plt.fill_between(reversed_recall, precision, step="post", alpha=0.2, color="b")

    for position, (title, class_names) in enumerate(groups, start=1):
        plt.subplot(2, 2, position)
        plt.title(title)
        make_sub_plot(class_names)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def discrete_cmap():
    """Create the discrete room/icon colormaps and register them with matplotlib.

    Four ``ListedColormap`` objects are registered, in this order:
    ``"rooms"``, ``"icons"``, ``"rooms_furu"`` and ``"icons_furu"``.

    Fixes relative to the previous version:
    * the last colour of the ``rooms_furu`` palette was ``"d3d5d7"`` (missing the
      leading ``#``), which is not a valid matplotlib colour specification;
    * the fourth palette was registered under ``"rooms_furu"`` a second time and
      therefore silently replaced the third one; it is now ``"icons_furu"``.
    """
    rooms_palette = [
        "#DCDCDC",
        "#b3de69",
        "#000000",
        "#8dd3c7",
        "#fdb462",
        "#fccde5",
        "#80b1d3",
        "#808080",
        "#fb8072",
        "#696969",
        "#577a4d",
        "#ffffb3",
    ]
    icons_palette = [
        "#DCDCDC",
        "#8dd3c7",
        "#b15928",
        "#fdb462",
        "#ffff99",
        "#fccde5",
        "#80b1d3",
        "#808080",
        "#fb8072",
        "#696969",
        "#577a4d",
    ]

    palettes = (
        ("rooms", rooms_palette),
        ("icons", icons_palette),
        # The "furu" room palette is the room palette plus one extra colour.
        ("rooms_furu", rooms_palette + ["#d3d5d7"]),
        ("icons_furu", icons_palette),
    )
    for cmap_name, hex_colors in palettes:
        cm.register_cmap(cmap=colors.ListedColormap(hex_colors, cmap_name))
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def segmentation_plot(rooms_pred, icons_pred, rooms_label, icons_label):
    """Show ground truth next to prediction for the room and icon segmentations.

    Two figures are displayed one after the other (rooms first, then icons),
    each with a colour bar labelled with the class names.

    Parameters
    ----------
    rooms_pred, icons_pred : images handed to ``imshow`` as predictions.
    rooms_label, icons_label : images handed to ``imshow`` as ground truth.
    """
    room_classes = [
        "Background",
        "Outdoor",
        "Wall",
        "Kitchen",
        "Living Room",
        "Bed Room",
        "Bath",
        "Entry",
        "Railing",
        "Storage",
        "Garage",
        "Undefined",
    ]
    icon_classes = [
        "No Icon",
        "Window",
        "Door",
        "Closet",
        "Electrical Applience",
        "Toilet",
        "Sink",
        "Sauna Bench",
        "Fire Place",
        "Bathtub",
        "Chimney",
    ]
    discrete_cmap()  # registers the "rooms" and "icons" colormaps used below

    def _show_comparison(kind, truth, prediction, cmap_name, class_names):
        """Display one ground-truth/prediction figure with a labelled colour bar."""
        highest_class = len(class_names) - 1
        fig, (truth_ax, pred_ax) = plt.subplots(nrows=1, ncols=2, figsize=(30, 15))

        truth_ax.set_title(kind + " Ground Truth")
        truth_ax.imshow(truth, cmap=cmap_name, vmin=0, vmax=highest_class)

        pred_ax.set_title(kind + " Prediction")
        image = pred_ax.imshow(prediction, cmap=cmap_name, vmin=0, vmax=highest_class)

        colorbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        tick_positions = np.arange(len(class_names)) + 0.5
        colorbar = fig.colorbar(image, cax=colorbar_ax, ticks=tick_positions)

        # Leave room on the right for the colour bar axes added above.
        fig.subplots_adjust(right=0.8)
        colorbar.ax.set_yticklabels(class_names)
        plt.show()

    _show_comparison("Room", rooms_label, rooms_pred, "rooms", room_classes)
    _show_comparison("Icon", icons_label, icons_pred, "icons", icon_classes)
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
def polygons_to_image(polygons, types, room_polygons, room_types, height, width):
    """Rasterise room and wall/icon polygons into two class-index images.

    Parameters
    ----------
    polygons : sequence of arrays indexed as ``pol[:, 0]`` (columns) and
        ``pol[:, 1]`` (rows); paired by position with ``types``.
    types : sequence of dicts with ``"type"`` and ``"class"`` keys.
    room_polygons : sequence of shapes accepted by ``shp_mask``; paired by
        position with ``room_types``.
    room_types : sequence of dicts with a ``"class"`` key.
    height, width : size of the output images.

    Returns
    -------
    (room_segmentation, icon_segmentation) : two ``(height, width)`` arrays.
    """
    room_canvas = np.zeros((height, width))
    icon_canvas = np.zeros((height, width))

    # Rooms first, so walls painted afterwards overwrite room pixels.
    grid_x, grid_y = np.arange(width), np.arange(height)
    for idx, room_shape in enumerate(room_polygons):
        inside = shp_mask(room_shape, grid_x, grid_y)
        room_canvas[inside] = room_types[idx]["class"]

    for idx, outline in enumerate(polygons):
        rows, cols = draw.polygon(outline[:, 1], outline[:, 0])
        # Walls belong to the room image; everything else is an icon.
        canvas = room_canvas if types[idx]["type"] == "wall" else icon_canvas
        canvas[rows, cols] = types[idx]["class"]

    return room_canvas, icon_canvas
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def plot_room(r, name, n_classes=12):
    """Render a room segmentation with the "rooms" colormap, save and show it.

    Parameters
    ----------
    r : image passed to ``imshow``.
    name : output path without extension; ``".png"`` is appended.
    n_classes : number of classes; the colour scale spans ``0 .. n_classes - 1``.
    """
    discrete_cmap()  # makes the "rooms" colormap available
    output_file = name + ".png"
    top_class = n_classes - 1

    plt.figure(figsize=(40, 30))
    plt.axis("off")
    plt.tight_layout()
    plt.imshow(r, cmap="rooms", vmin=0, vmax=top_class)
    plt.savefig(output_file, format="png")
    plt.show()
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
def plot_icon(i, name, n_classes=11):
    """Render an icon segmentation with the "icons" colormap, save and show it.

    Parameters
    ----------
    i : image passed to ``imshow``.
    name : output path without extension; ``".png"`` is appended.
    n_classes : number of classes; the colour scale spans ``0 .. n_classes - 1``.
    """
    discrete_cmap()  # makes the "icons" colormap available
    output_file = name + ".png"
    top_class = n_classes - 1

    plt.figure(figsize=(40, 30))
    plt.axis("off")
    plt.tight_layout()
    plt.imshow(i, cmap="icons", vmin=0, vmax=top_class)
    plt.savefig(output_file, format="png")
    plt.show()
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def plot_heatmaps(h, name):
    """Save and show every heatmap in ``h`` as its own PNG figure.

    Parameters
    ----------
    h : iterable of images; each is drawn with the "Reds" colormap on a fixed
        0..1 colour scale.
    name : output path prefix; the heatmap's position in ``h`` and ``".png"``
        are appended to form each file name.
    """
    for position, heatmap in enumerate(h):
        output_file = name + str(position) + ".png"

        plt.figure(figsize=(40, 30))
        plt.axis("off")
        plt.tight_layout()
        plt.imshow(heatmap, cmap="Reds", vmin=0, vmax=1)
        plt.savefig(output_file, format="png")
        plt.show()
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def outline_to_mask(line, x, y):
    """Build a boolean raster mask from a closed outline.

    Parameters
    ----------
    line : array-like (N, 2)
        Vertices of the outline, used to construct a matplotlib ``Path``.
    x, y : 1-D grid coordinates (the inputs given to ``np.meshgrid``)

    Returns
    -------
    mask : 2-D boolean array of shape ``(len(y), len(x))``, True inside the outline

    Examples
    --------
    >>> from shapely.geometry import Point
    >>> poly = Point(0,0).buffer(1)
    >>> x = np.linspace(-5,5,100)
    >>> y = np.linspace(-5,5,100)
    >>> mask = outline_to_mask(poly.boundary, x, y)
    """
    contour = mplp.Path(line)
    grid_x, grid_y = np.meshgrid(x, y)
    # One (x, y) row per grid node, in the same order as the flattened grid.
    nodes = np.array((grid_x.ravel(), grid_y.ravel())).T
    inside = contour.contains_points(nodes)
    return inside.reshape(grid_x.shape)
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def _grid_bbox(x, y):
|
| 756 |
+
dx = dy = 0
|
| 757 |
+
return x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def _bbox_to_rect(bbox):
    """Convert a ``(left, right, bottom, top)`` tuple into a rectangular Polygon."""
    left, right, bottom, top = bbox
    corners = [
        (left, bottom),
        (right, bottom),
        (right, top),
        (left, top),
    ]
    return Polygon(corners)
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def shp_mask(shp, x, y, m=None):
    """
    Adapted from code written by perrette
    from: https://gist.github.com/perrette/a78f99b76aed54b6babf3597e0b331f8
    Rasterise a shape onto a regular grid by recursively subdividing the grid
    and testing each sub-rectangle with the shape's ``intersects``/``contains``.

    Parameters
    ----------
    shp : shapely's Polygon (or any object with "contains" and "intersects" methods)
    x, y : 1-D numpy arrays defining a regular grid
    m : mask to fill, optional (will be created otherwise)

    Returns
    -------
    m : boolean 2-D array, True inside shape.

    Examples
    --------
    >>> from shapely.geometry import Point
    >>> poly = Point(0,0).buffer(1)
    >>> x = np.linspace(-5,5,100)
    >>> y = np.linspace(-5,5,100)
    >>> mask = shp_mask(poly, x, y)
    """
    grid_rect = _bbox_to_rect(_grid_bbox(x, y))

    if m is None:
        m = np.zeros((y.size, x.size), dtype=bool)

    # Whole tile outside or inside the shape: fill it in one go.
    if not shp.intersects(grid_rect):
        m[:] = False
        return m
    if shp.contains(grid_rect):
        m[:] = True
        return m

    n_rows, n_cols = m.shape

    # Single grid node left: decide it with a point test.
    if n_rows == 1 and n_cols == 1:
        m[:] = shp.contains(Point(x[0], y[0]))
        return m

    row_mid, col_mid = n_rows // 2, n_cols // 2

    if n_rows == 1:
        # Only columns can still be split.
        m[:, :col_mid] = shp_mask(shp, x[:col_mid], y, m[:, :col_mid])
        m[:, col_mid:] = shp_mask(shp, x[col_mid:], y, m[:, col_mid:])
    elif n_cols == 1:
        # Only rows can still be split.
        m[:row_mid] = shp_mask(shp, x, y[:row_mid], m[:row_mid])
        m[row_mid:] = shp_mask(shp, x, y[row_mid:], m[row_mid:])
    else:
        # Recurse into the four quadrants.
        m[:row_mid, :col_mid] = shp_mask(shp, x[:col_mid], y[:row_mid], m[:row_mid, :col_mid])
        m[:row_mid, col_mid:] = shp_mask(shp, x[col_mid:], y[:row_mid], m[:row_mid, col_mid:])
        m[row_mid:, :col_mid] = shp_mask(shp, x[:col_mid], y[row_mid:], m[row_mid:, :col_mid])
        m[row_mid:, col_mid:] = shp_mask(shp, x[col_mid:], y[row_mid:], m[row_mid:, col_mid:])

    return m
|
data_preprocess/cubicasa5k/run.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Step 1: create the COCO-style dataset for CubiCasa5k.
python create_coco_cc5k.py --data_root=data/cubicasa5k/ \
    --output=data/coco_cubicasa5k_nowalls_v4/ \
    --disable_wd2line

# Step 2: split examples that contain more than one floorplan into separate samples.
python floorplan_extraction.py \
    --data_root data/coco_cubicasa5k_nowalls_v4/ \
    --output data/coco_cubicasa5k_nowalls_v4-1_refined/

# Step 3: merge the individual JSONs into a single JSON file per split (train/val/test).
# This must be done after floorplan_extraction.py.
# (The last argument previously ended with a dangling "\" line continuation.)
python combine_json.py \
    --input data/coco_cubicasa5k_nowalls_v4-1_refined/ \
    --output data/coco_cubicasa5k_nowalls_v4-1_refined/annotations/
|
data_preprocess/cubicasa5k/svg_utils.py
ADDED
|
@@ -0,0 +1,746 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from logging import warning
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from skimage.draw import polygon
|
| 6 |
+
from svgpathtools import parse_path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_room_number(e, rooms):
    """Look up the numeric label of a room element.

    The room type is the second space-separated token of the element's
    ``class`` attribute. Unknown types are logged as a warning and mapped to
    ``rooms["Undefined"]``.

    Parameters
    ----------
    e : element exposing ``getAttribute``.
    rooms : mapping from room type name to label; must contain ``"Undefined"``.
    """
    room_type = e.getAttribute("class").split(" ")[1]
    if room_type in rooms:
        return rooms[room_type]

    warning("Room type " + e.getAttribute("class") + " not defined.")
    return rooms["Undefined"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_icon_number(e, icons):
    """Look up the numeric label of an icon element.

    The icon type is the second space-separated token of the element's
    ``class`` attribute. Unknown types are logged as a warning and mapped to
    ``icons["Misc"]``.

    Parameters
    ----------
    e : element exposing ``getAttribute``.
    icons : mapping from icon type name to label; must contain ``"Misc"``.
    """
    icon_type = e.getAttribute("class").split(" ")[1]
    if icon_type in icons:
        return icons[icon_type]

    warning("Icon type " + e.getAttribute("class") + " not defined.")
    return icons["Misc"]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _parse_svg_matrix(transform):
    """Parse an SVG ``matrix(a,b,c,d,e,f)`` string into a 3x3 affine matrix."""
    parts = transform.split(",")
    a = float(parts[0][7:])  # drop the leading "matrix("
    b = float(parts[1])
    c = float(parts[2])
    d = float(parts[3])
    e = float(parts[-2])
    f = float(parts[-1][:-1])  # drop the trailing ")"
    return np.array([[a, c, e], [b, d, f], [0, 0, 1]])


def get_icon(ee):
    """Return the rasterised footprint and outline of an icon element.

    The outline is taken, in order of preference, from:
    1. the first ``polygon`` child of the element's ``BoundaryPolygon`` group;
    2. the bounding box of that group's corners when it has no ``polygon``;
    3. the bounding box of all ``g`` children when there is no
       ``BoundaryPolygon`` group at all.

    The outline is then mapped through the element's own ``transform`` and, when
    the parent has class ``FixedFurnitureSet``, through the parent's transform
    as well. Coordinates are rounded after the transforms are applied.

    Plain ndarrays are used for the affine maps instead of ``np.matrix`` so no
    1x1 matrices are assigned into array elements (deprecated in NumPy).

    Parameters
    ----------
    ee : DOM element with a ``matrix(...)`` ``transform`` attribute.

    Returns
    -------
    rr, cc : pixel row/column indices inside the outline, or ``None, None``
        when the outline has fewer than four vertices.
    X, Y : outline vertex coordinates (untransformed when fewer than four).
    """
    parent_matrix = None
    if ee.parentNode.getAttribute("class") == "FixedFurnitureSet":
        parent_matrix = _parse_svg_matrix(ee.parentNode.getAttribute("transform"))

    own_matrix = _parse_svg_matrix(ee.getAttribute("transform"))

    boundary = next(
        (p for p in ee.childNodes if p.nodeName == "g" and p.getAttribute("class") == "BoundaryPolygon"),
        None,
    )
    if boundary is None:
        X, Y = make_boudary_polygon(ee)
    else:
        outline = next((p for p in boundary.childNodes if p.nodeName == "polygon"), None)
        if outline is not None:
            X, Y = get_icon_polygon(outline)
        else:
            # No explicit polygon: fall back to the bounding box of the group.
            x_all, y_all = get_corners(boundary)
            X, Y = get_max_corners(np.column_stack((x_all, y_all)))

    if len(X) < 4:
        return None, None, X, Y

    # Homogeneous coordinates, one column per vertex.
    vertices = np.vstack((X, Y, np.ones(len(X))))
    mapped = np.matmul(own_matrix, vertices)
    if parent_matrix is not None:
        mapped = np.matmul(parent_matrix, mapped)
    mapped = np.round(mapped)
    X, Y = mapped[0], mapped[1]

    rr, cc = polygon(Y, X)

    return rr, cc, X, Y
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_corners(g):
    """Collect corner coordinates of the polygon, path and rect children of ``g``.

    * ``polygon`` children contribute the points from ``get_icon_polygon``;
    * ``path`` children contribute the points from ``get_icon_path``;
    * ``rect`` children contribute their ``(x, y)`` and ``(x + width, y + height)``
      corners, where a missing ``x`` or ``y`` attribute is treated as ``1.0``.

    Returns
    -------
    x_all, y_all : coordinate arrays (empty lists when no child contributed).
    """
    x_all, y_all = [], []
    for node in g.childNodes:
        kind = node.nodeName
        if kind == "polygon":
            xs, ys = get_icon_polygon(node)
        elif kind == "path":
            xs, ys = get_icon_path(node)
        elif kind == "rect":
            raw_x = node.getAttribute("x")
            left = float(raw_x) if raw_x != "" else 1.0
            raw_y = node.getAttribute("y")
            top = float(raw_y) if raw_y != "" else 1.0
            rect_w = float(node.getAttribute("width"))
            rect_h = float(node.getAttribute("height"))
            xs = [left, left + rect_w]
            ys = [top, top + rect_h]
        else:
            continue

        x_all = np.append(x_all, xs)
        y_all = np.append(y_all, ys)

    return x_all, y_all
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def get_max_corners(points):
    """Return the axis-aligned bounding rectangle of a set of 2-D points.

    The previous implementation updated the maximum y in an ``elif`` attached to
    the maximum-x test, so y was skipped whenever x set a new maximum and the
    result could keep ``maxy == -inf``. Both axes are now reduced independently.

    Parameters
    ----------
    points : array-like of shape (N, 2) holding ``(x, y)`` pairs.

    Returns
    -------
    X, Y : arrays of the four rectangle corners, ordered
        (min, min), (max, min), (max, max), (min, max);
        two empty lists when ``points`` is empty.
    """
    if len(points) == 0:
        return [], []

    coords = np.asarray(points)
    minx, miny = coords.min(axis=0)
    maxx, maxy = coords.max(axis=0)

    X = np.array([minx, maxx, maxx, minx])
    Y = np.array([miny, miny, maxy, maxy])

    return X, Y
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def make_boudary_polygon(pol):
    """Build a bounding rectangle from the corners of every ``g`` child of ``pol``.

    Returns
    -------
    X, Y : the rectangle corners as produced by ``get_max_corners`` (two empty
        lists when no corner was found).
    """
    xs, ys = [], []
    for child in pol.childNodes:
        if child.nodeName != "g":
            continue
        group_x, group_y = get_corners(child)
        xs = np.append(xs, group_x)
        ys = np.append(ys, group_y)

    corner_table = np.column_stack((xs, ys))
    X, Y = get_max_corners(corner_table)
    return X, Y
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_icon_path(pol):
    """Return the bounding-box corners of an SVG ``path`` element.

    The ``d`` attribute is parsed with ``parse_path`` and its bounding box is
    returned as four corners. Two empty arrays are returned when parsing raises
    ``ValueError`` (the error is printed) or when the box has zero width or
    zero height.

    Returns
    -------
    X, Y : arrays ordered (min, min), (max, min), (max, max), (min, max).
    """
    path = pol.getAttribute("d")
    try:
        minx, maxx, miny, maxy = parse_path(path).bbox()
    except ValueError as e:
        print("Error handled")
        print(e)
        return np.array([]), np.array([])

    X = np.array([minx, maxx, maxx, minx])
    Y = np.array([miny, miny, maxy, maxy])

    # A box collapsed along either axis is not a usable outline.
    collapsed = np.unique(X).size == 1 or np.unique(Y).size == 1
    if collapsed:
        return np.array([]), np.array([])

    return X, Y
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_icon_polygon(pol):
    """Return the vertex coordinates of a ``polygon`` element as ``(X, Y)``.

    The ``points`` attribute is split on single spaces and handed to ``get_XY``.
    """
    raw_points = pol.getAttribute("points")
    tokens = raw_points.split(" ")
    return get_XY(tokens)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_XY(points):
    """Convert a list of SVG point tokens into coordinate arrays.

    Two token formats are supported:
    * ``"x,y"`` - both values are rounded to the nearest integer;
    * bare numbers - tokens at even positions are x values and tokens at odd
      positions are y values (these are NOT rounded, as before).

    Changes relative to the previous version:
    * empty tokens are ignored, so an empty ``points`` attribute (``[""]``)
      yields two empty arrays instead of raising ``IndexError``;
    * the redundant ``len(a) == 2`` special case is gone (both branches took
      the first two comma-separated fields);
    * coordinates are gathered in lists and converted once rather than growing
      arrays with repeated ``np.append``.

    Parameters
    ----------
    points : list of str tokens.

    Returns
    -------
    X, Y : float arrays of x and y coordinates.
    """
    tokens = [token for token in points if token != ""]

    xs, ys = [], []
    for position, token in enumerate(tokens):
        if "," in token:
            fields = token.split(",")
            xs.append(np.round(float(fields[0])))
            ys.append(np.round(float(fields[1])))
        elif position % 2:
            # No comma: every other token is x and every other is y.
            ys.append(float(token))
        else:
            xs.append(float(token))

    return np.array(xs, dtype=float), np.array(ys, dtype=float)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_points(e):
    """Return the rounded vertex coordinates of the first ``polygon`` child of ``e``.

    The ``points`` attribute is split on whitespace, so a trailing separator is
    ignored. The previous version always discarded the last token, which
    silently dropped a real vertex whenever the attribute had no trailing space.

    Parameters
    ----------
    e : element whose ``childNodes`` contain a ``polygon`` node with ``"x,y"``
        point tokens.

    Returns
    -------
    X, Y : float arrays of rounded x and y coordinates.

    Raises
    ------
    StopIteration : when ``e`` has no ``polygon`` child.
    """
    pol = next(p for p in e.childNodes if p.nodeName == "polygon")

    xs, ys = [], []
    for token in pol.getAttribute("points").split():
        x, y = token.split(",")
        xs.append(np.round(float(x)))
        ys.append(np.round(float(y)))

    return np.array(xs, dtype=float), np.array(ys, dtype=float)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def get_direction(X, Y):
    """Classify a shape as horizontal (``"H"``) or vertical (``"V"``).

    The shape is horizontal only when its x extent is strictly larger than its
    y extent; ties are reported as vertical.
    """
    extent_x = abs(max(X) - min(X))
    extent_y = abs(max(Y) - min(Y))
    return "H" if extent_x > extent_y else "V"
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def get_polygon(e):
    """Rasterise the first ``polygon`` child of ``e`` into pixel indices.

    Each ``"x,y"`` token is rounded; the y values are used as rows and the x
    values as columns when calling ``skimage.draw.polygon``.

    The ``points`` attribute is split on whitespace, so a trailing separator is
    ignored. The previous version always discarded the last token, which
    silently dropped a real vertex whenever the attribute had no trailing space.

    Returns
    -------
    rr, cc : row and column indices of the pixels inside the polygon.

    Raises
    ------
    StopIteration : when ``e`` has no ``polygon`` child.
    """
    pol = next(p for p in e.childNodes if p.nodeName == "polygon")

    rows, cols = [], []
    for token in pol.getAttribute("points").split():
        svg_x, svg_y = token.split(",")
        rows.append(np.round(float(svg_y)))
        cols.append(np.round(float(svg_x)))

    rr, cc = polygon(np.array(rows, dtype=float), np.array(cols, dtype=float))

    return rr, cc
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def calc_distance(point_1, point_2):
    """Return the Euclidean distance between two 2-D points."""
    delta_x = point_1[0] - point_2[0]
    delta_y = point_1[1] - point_2[1]
    squared_sum = math.pow(delta_x, 2) + math.pow(delta_y, 2)
    return math.sqrt(squared_sum)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def calc_center(points):
    """Return the coordinate-wise mean of ``points`` as a list."""
    stacked = np.array(points)
    centroid = stacked.mean(axis=0)
    return list(centroid)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def get_gaussian2D(ndim, sigma=0.25):
    """Return an ``ndim`` x ``ndim`` isotropic Gaussian bump centred on the grid.

    Entry ``[v][u]`` equals ``exp(-0.5 * (du**2 + dv**2))`` with
    ``du = (u + 1 - centre) / (sigma * ndim)`` (likewise ``dv``) and
    ``centre = 0.5 * ndim + 0.5``. The kernel is not normalised; for odd
    ``ndim`` the centre entry is exactly 1.

    The per-pixel Python loops of the previous version are replaced by one
    vectorised expression over the (identical) per-axis offsets.

    Parameters
    ----------
    ndim : side length of the square kernel.
    sigma : standard deviation expressed as a fraction of ``ndim``.
    """
    inv_scale = 1.0 / (sigma * ndim)
    centre = 0.5 * ndim + 0.5

    # Scaled 1-based offsets from the centre; shared by both axes.
    offsets = (np.arange(1, ndim + 1) - centre) * inv_scale
    squared = offsets * offsets

    return np.exp(-0.5 * (squared[np.newaxis, :] + squared[:, np.newaxis]))
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def draw_junction(index, point, width, height, axes):
    """Draw a labelled junction glyph on ``axes``.

    The junction index is written at the junction position, then one short
    line is drawn for each arm of the junction. Which arms are drawn, and in
    which colour, is selected by the draw code ``(point[2][1], point[2][2])``;
    unknown draw codes draw only the label.

    Arms are 15 units long and clipped to ``[0, width - 1]`` horizontally and
    ``[0, height - 1]`` vertically. "down" means increasing y.

    Parameters
    ----------
    index : value rendered (via ``str``) as the text label.
    point : sequence where ``point[0]`` is the ``(x, y)`` position and
        ``point[2][1]``, ``point[2][2]`` form the draw code.
    width, height : extent used for clipping the arms.
    axes : object providing ``text`` and ``plot`` (e.g. matplotlib axes).
    """
    arm_length = 15
    arm_width = 7
    x, y = point[0]
    axes.text(x, y, str(index), fontsize=15, color="k")

    # draw code -> (colour, arms in drawing order)
    glyphs = {
        # single arm
        (1, 1): ("#6488ea", ("down",)),  # soft blue
        (1, 2): ("#6241c7", ("left",)),  # bluey purple
        (1, 3): ("#056eee", ("up",)),  # cerulean blue
        (1, 4): ("#004577", ("right",)),  # prussian blue
        # corners
        (2, 3): ("#04d8b2", ("right", "down")),
        (2, 4): ("#cdfd02", ("left", "down")),
        (2, 1): ("#ff81c0", ("left", "up")),
        (2, 2): ("#f97306", ("right", "up")),
        # T-junctions
        (3, 4): ("b", ("right", "up", "down")),
        (3, 1): ("y", ("right", "left", "down")),
        (3, 2): ("r", ("left", "up", "down")),
        (3, 3): ("m", ("right", "left", "up")),
        # cross
        (4, 1): ("k", ("right", "left", "up", "down")),
    }

    glyph = glyphs.get((point[2][1], point[2][2]))
    if glyph is None:
        return
    color, arms = glyph

    for arm in arms:
        if arm == "down":
            xs, ys = [x, x], [y, min(y + arm_length, height - 1)]
        elif arm == "up":
            xs, ys = [x, x], [y, max(y - arm_length, 0)]
        elif arm == "left":
            xs, ys = [x, max(x - arm_length, 0)], [y, y]
        else:  # "right"
            xs, ys = [x, min(x + arm_length, width - 1)], [y, y]
        axes.plot(xs, ys, linewidth=arm_width, color=color)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
class Wall:
    """Base record for a wall segment: identity, end points, orientation, width."""

    def __init__(self, id, end_points, direction, width, name):
        """Store the wall attributes.

        Args:
            id: Identifier of the wall.
            end_points: Pair of ``[x, y]`` end points.
            direction: ``"H"`` for horizontal or ``"V"`` for vertical.
            width: Initial value for both ``max_width`` and ``min_width``.
            name: Name of the wall.
        """
        self.id, self.name = id, name
        self.end_points, self.direction = end_points, direction
        # A fresh wall has a single width; the two bounds only diverge later.
        self.max_width = self.min_width = width

    def change_end_points(self):
        """Snap both end points onto the mean of ``self.min_coord``.

        A vertical wall gets its x coordinates replaced, a horizontal wall
        its y coordinates.  Any other direction leaves the wall untouched.
        ``min_coord`` is not set by this class; it must exist on the instance.
        """
        if self.direction == "V":
            axis = 0
        elif self.direction == "H":
            axis = 1
        else:
            return

        self.end_points[0][axis] = np.mean(np.array(self.min_coord))
        # Read the value back so both end points hold exactly what was stored.
        self.end_points[1][axis] = self.end_points[0][axis]

    def get_length(self, end_points):
        """Return the distance between the two points in *end_points*."""
        first, second = end_points[0], end_points[1]
        return calc_distance(first, second)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class LineWall(Wall):
    """Wall described only by its centre line; adds nothing to ``Wall``."""

    def __init__(self, id, end_points, direction, width, name):
        """Forward every argument unchanged to ``Wall.__init__``."""
        super().__init__(id, end_points, direction, width, name)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class PolygonWall(Wall):
    """Wall built from the four-corner ``polygon`` child of a DOM element.

    The constructor derives the raster footprint (``rr``/``cc``), the
    orientation, the centre-line end points, the width, the length and the
    centre of the wall from the polygon corners.
    """

    def __init__(self, e, id, shape=None):
        """Parse element *e* and compute the wall geometry.

        Args:
            e: DOM element exposing ``getAttribute`` and ``childNodes``; it
                must contain a ``polygon`` child with a ``points`` attribute.
            id: Identifier stored on the wall.
            shape: Optional pair; when truthy, X is clipped to
                ``[0, shape[1]]`` and Y to ``[0, shape[0]]``.

        Raises:
            ValueError: With the message ``"small wall"`` when the polygon
                spans less than 4 units along X or along Y.
        """
        self.id = id
        self.name = e.getAttribute("id")
        self.X, self.Y = self.get_points(e)
        if abs(max(self.X) - min(self.X)) < 4 or abs(max(self.Y) - min(self.Y)) < 4:
            # wall is too small and we ignore it.
            raise ValueError("small wall")
        if shape:
            self.X = np.clip(self.X, 0, shape[1])
            self.Y = np.clip(self.Y, 0, shape[0])
        # NOTE(review): `polygon` is imported elsewhere in this file; it is
        # called as polygon(rows, cols) -- presumably skimage.draw.polygon.
        self.rr, self.cc = polygon(self.Y, self.X)
        direction = self.get_direction(self.X, self.Y)
        end_points = self.get_end_points(self.X, self.Y, direction)
        self.min_width = self.get_width(self.X, self.Y, direction)
        self.max_width = self.min_width

        Wall.__init__(self, id, end_points, direction, self.max_width, self.name)
        self.length = self.get_length(self.end_points)
        self.center = self.get_center(self.X, self.Y)
        # Needs self.direction, so it has to run after Wall.__init__.
        self.min_coord, self.max_coord = self.get_width_coods(self.X, self.Y)

    def get_points(self, e):
        """Return the rounded polygon corner coordinates as ``(X, Y)`` arrays.

        The ``points`` attribute is split on single spaces and the last token
        is discarded before each remaining ``"x,y"`` token is parsed.
        """
        pol = next(p for p in e.childNodes if p.nodeName == "polygon")
        tokens = pol.getAttribute("points").split(" ")[:-1]

        xs, ys = [], []
        for token in tokens:
            x, y = token.split(",")
            xs.append(np.round(float(x)))
            ys.append(np.round(float(y)))

        return np.array(xs, dtype=float), np.array(ys, dtype=float)

    def get_direction(self, X, Y):
        """Return ``"H"`` when the X extent exceeds the Y extent, else ``"V"``."""
        max_diff_X = abs(max(X) - min(X))
        max_diff_Y = abs(max(Y) - min(Y))

        if max_diff_X > max_diff_Y:
            return "H"  # horizontal
        return "V"  # vertical

    def get_center(self, X, Y):
        """Return the mean of the corner coordinates as ``(x, y)``."""
        return np.mean(X), np.mean(Y)

    def get_width(self, X, Y, direction):
        """Return the average extent of the two shortest edges across the wall.

        For ``"H"`` the extent is measured along y, otherwise along x.
        """
        _, _, p1, p2 = self._get_min_points(X, Y)

        if direction == "H":
            return (abs(p1[0][1] - p1[1][1]) + abs(p2[0][1] - p2[1][1])) / 2
        # The original tested `elif "V":`, which is always true; every
        # non-"H" direction is therefore handled as vertical.
        return (abs(p1[0][0] - p1[1][0]) + abs(p2[0][0] - p2[1][0])) / 2

    def _width(self, values):
        """Return ``max(values) - min(values)`` as a non-negative number."""
        temp = values.tolist() if type(values) is not list else values

        mean_1 = min(temp)
        mean_2 = max(temp)

        return abs(mean_1 - mean_2)

    def merge_possible(self, merged):
        """Return True when *merged* is a candidate for merging into this wall.

        Requires a different id, the same direction, a width difference not
        larger than ``merged.max_width`` and a pair of opposite end points
        within 1.5 times the larger ``max_width``.
        """
        max_dist = max([self.max_width, merged.max_width])

        if self.id == merged.id:
            return False

        # walls have to be in the same direction
        if self.direction != merged.direction:
            return False

        # walls have too big width difference
        if abs(self.max_width - merged.max_width) > merged.max_width:
            return False

        # If endpoints are near
        # self up and left endpoint to merged down and right end point
        dist1 = calc_distance(self.end_points[0], merged.end_points[1])
        # self down and right endpoint to merged up and left end point
        dist2 = calc_distance(self.end_points[1], merged.end_points[0])

        return dist1 <= max_dist * 1.5 or dist2 <= max_dist * 1.5

    def _get_overlap(self, a, b):
        """Return the length of the overlap of intervals *a* and *b* (0 if none)."""
        return max(0, min(a[1], b[1]) - max(a[0], b[0]))

    def merge_walls(self, merged):
        """Merge *merged* into this wall when they touch end to end.

        Returns:
            ``self`` after merging, or ``None`` when the walls share an id,
            differ in direction, are too far apart, or their ``min_coord``
            intervals do not overlap.
        """
        max_dist = max([self.max_width, merged.max_width])

        if self.id == merged.id:
            return None

        # walls have to be in the same direction
        if self.direction != merged.direction:
            return None

        # If endpoints are near
        # self up and left endpoint to merged down and right end point
        dist1 = calc_distance(self.end_points[0], merged.end_points[1])
        # self down and right endpoint to merged up and left end point
        dist2 = calc_distance(self.end_points[1], merged.end_points[0])

        if dist1 <= max_dist * 1.5:
            if self._get_overlap(self.min_coord, merged.min_coord) <= 0:
                return None
            # merged is on top or on left
            return self.do_merge(merged, 0)
        elif dist2 <= max_dist * 1.5:
            if self._get_overlap(self.min_coord, merged.min_coord) <= 0:
                return None
            # merged is on down or on right
            return self.do_merge(merged, 1)
        else:
            return None

    def _get_min_points(self, X, Y):
        """Find the two shortest edges of the four-corner polygon.

        Returns:
            ``(point1, point2, corners1, corners2)``: the midpoints of the
            shortest and second-shortest edge, followed by the corner pair
            of each of those edges.
        """
        assert len(X) == 4 and len(Y) == 4
        length = len(X)
        min_dist1 = np.inf
        min_dist2 = np.inf
        point1 = None
        point2 = None
        corners1 = None
        corners2 = None

        for i in range(length):
            x1, y1 = X[i], Y[i]
            x2, y2 = X[(i + 1) % 4], Y[(i + 1) % 4]

            dist = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
            if dist < min_dist1:
                # New shortest edge: the previous shortest becomes second.
                point2 = point1
                point1 = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
                min_dist2 = min_dist1
                min_dist1 = dist
                corners2 = corners1
                corners1 = np.array([[x1, y1], [x2, y2]])
            elif dist <= min_dist2:
                point2 = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
                min_dist2 = dist
                corners2 = np.array([[x1, y1], [x2, y2]])

        return point1, point2, corners1, corners2

    def get_end_points(self, X, Y, direction):
        """Return the two centre-line end points, ordered by *direction*.

        The end points are the midpoints of the two shortest polygon edges.
        They are aligned onto a shared coordinate: the mean y when they are
        further apart in x (horizontal), otherwise the mean x (vertical).
        """
        point1, point2, _, _ = self._get_min_points(X, Y)

        if point1[0] != point2[0] or point1[1] != point2[1]:
            if abs(point1[0] - point2[0]) > abs(point1[1] - point2[1]):
                # horizontal
                # Parenthesised on purpose: `a + b / 2.0` is not the mean.
                point1[1] = (point1[1] + point2[1]) / 2.0
                point2[1] = point1[1]
            else:
                # vertical
                point1[0] = (point1[0] + point2[0]) / 2.0
                point2[0] = point1[0]

        return self.sort_end_points(direction, point1, point2)

    def sort_end_points(self, direction, point1, point2):
        """Return both points as an array, smaller coordinate first.

        Vertical walls are ordered by y, all other directions by x.  On a tie
        *point2* comes first.
        """
        axis = 1 if direction == "V" else 0
        if point1[axis] < point2[axis]:
            return np.array([point1, point2])
        return np.array([point2, point1])

    def do_merge(self, merged, direction):
        """Absorb *merged* into this wall and return ``self``.

        Args:
            merged: Wall to absorb.
            direction: ``0`` when *merged* lies above/left of this wall,
                anything else when it lies below/right.
        """
        # update width
        self.max_width = max([self.max_width, merged.max_width])
        self.min_width = min([self.min_width, merged.min_width])

        # update polygon
        self.X = np.concatenate((self.X, merged.X))
        self.Y = np.concatenate((self.Y, merged.Y))

        # update width coordinates
        self.max_coord = self.get_max_width_coord(merged)
        self.min_coord = self.get_min_width_coord(merged)

        if direction == 0:
            # merged wall is up or left to the original wall
            self.end_points = np.array([merged.end_points[0], self.end_points[1]])
        else:
            # merged wall is down or right to the original wall
            self.end_points = np.array([self.end_points[0], merged.end_points[1]])

        self.length = self.get_length(self.end_points)

        return self

    def get_max_width_coord(self, merged):
        """Return the wider of the two walls' ``max_coord`` intervals."""
        width_1 = abs(self.max_coord[0] - self.max_coord[1])
        width_2 = abs(merged.max_coord[0] - merged.max_coord[1])
        return self.max_coord if width_1 > width_2 else merged.max_coord

    def get_min_width_coord(self, merged):
        """Return ``[max of the first entries, min of the second entries]``
        of the two walls' ``min_coord`` intervals."""
        lower = max(merged.min_coord[0], self.min_coord[0])
        upper = min(merged.min_coord[1], self.min_coord[1])
        return [lower, upper]

    def get_width_coods(self, X, Y):
        """Return ``(min_coord, max_coord)``: the corner-coordinate pair with
        the smaller spread followed by the one with the larger spread.

        Horizontal walls pair ``Y[0], Y[2]`` against ``Y[1], Y[3]``; vertical
        walls pair ``X[0], X[3]`` against ``X[1], X[2]``.  Any other
        direction falls through and returns ``None``.
        """
        if self.direction == "H":
            dist_1 = abs(Y[0] - Y[2])
            dist_2 = abs(Y[1] - Y[3])
            if dist_1 < dist_2:
                return [Y[0], Y[2]], [Y[1], Y[3]]
            else:
                return [Y[1], Y[3]], [Y[0], Y[2]]

        elif self.direction == "V":
            dist_1 = abs(X[0] - X[3])
            dist_2 = abs(X[1] - X[2])
            if dist_1 < dist_2:
                return [X[0], X[3]], [X[1], X[2]]
            else:
                return [X[1], X[2]], [X[0], X[3]]

    def sort_X_Y(self, X, Y):
        """Reorder the corners as top-left, top-right, bottom-left, bottom-right.

        Each corner is written to the slot of the bounding-box corner it is
        closest to; two corners mapping to the same slot overwrite each other.
        """
        max_x = max(X)
        min_x = min(X)
        max_y = max(Y)
        min_y = min(Y)
        res_X, res_Y = [0] * 4, [0] * 4
        # top left 0, top right 1, bottom left 2, bottom right 3
        directions = [[min_x, min_y], [max_x, min_y], [min_x, max_y], [max_x, max_y]]
        length = len(X)
        for i in range(length):
            min_dist = 1000000
            direction_candidate = None
            for j, direc in enumerate(directions):
                dist = calc_distance([X[i], Y[i]], direc)
                if dist < min_dist:
                    min_dist = dist
                    direction_candidate = j

            res_X[direction_candidate] = X[i]
            res_Y[direction_candidate] = Y[i]

        return res_X, res_Y

    def wall_is_pillar(self, avg_wall_width):
        """Return True when the wall is wider than *avg_wall_width* and
        shorter than three times its own ``max_width``."""
        if self.max_width > avg_wall_width:
            if self.length < 3 * self.max_width:
                return True

        return False

    def split_pillar_wall(self, ids, avg_wall_width):
        """Replace a pillar by four ``LineWall`` edges along its inset outline.

        The stored corners in ``self.X``/``self.Y`` are modified in place.

        Args:
            ids: Id of the first new wall; the others get ``ids + 1 .. ids + 3``.
            avg_wall_width: Average wall width; each new wall gets half of it
                as width, and the corners are inset by a third of it.

        Returns:
            List of four ``LineWall`` objects.
        """
        inset = avg_wall_width / 3.0
        end_points = [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]
        # Every statement is repeated on purpose: each pass moves the single
        # most extreme remaining corner, so two passes move both corners on
        # that side of the pillar.
        self.X[np.argmax(self.X)] = max(self.X) - inset
        self.X[np.argmax(self.X)] = max(self.X) - inset
        self.X[np.argmin(self.X)] = min(self.X) + inset
        self.X[np.argmin(self.X)] = min(self.X) + inset
        self.Y[np.argmax(self.Y)] = max(self.Y) - inset
        self.Y[np.argmax(self.Y)] = max(self.Y) - inset
        self.Y[np.argmin(self.Y)] = min(self.Y) + inset
        self.Y[np.argmin(self.Y)] = min(self.Y) + inset
        for i in range(4):
            end = [self.X[i], self.Y[i]]
            j = i % 2
            # Corner i closes segment i and segment i - 1 (cyclically).
            end_points[i][j] = end
            end_points[(i + 3) % 4][j] = end

        walls = []
        for i, e in enumerate(end_points):
            if abs(e[0][1] - e[1][1]) > abs(e[0][0] - e[1][0]):
                # vertical wall
                direction = "V"
            else:
                # horizontal wall
                direction = "H"

            e = self.sort_end_points(direction, e[0], e[1])
            wall = LineWall(ids + i, e, direction, avg_wall_width / 2.0, self.name)
            walls.append(wall)

        return walls
|
data_preprocess/raster2graph/combine_json.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def combine_json_files(input_pattern, data_path, split_type, output_file, output_image_dir, start_image_id=0):
    """Combine multiple COCO-style JSON annotation files into a single file.

    Images are renumbered consecutively from *start_image_id* and renamed to
    ``<new id, zero-padded to 6 digits>.png``; annotations are renumbered
    from 0 and re-pointed at the new image ids.  Next to *output_file* the
    function also writes ``<name>_image_id_mapping.json`` and, when anything
    was skipped, ``<name>_skipped.txt``.

    Args:
        input_pattern: Glob pattern matching the input JSON files
            (e.g., "annotations/*.json").
        data_path: Root directory holding the source images.
        split_type: Sub-directory of *data_path* with this split's images.
        output_file: Path of the combined JSON file to write.
        output_image_dir: Directory that receives the renamed image copies;
            created if missing.
        start_image_id: First image id to assign.

    Returns:
        Tuple ``(combined_data, image_id_mapping_list)`` where the second
        item is a list of one-element lists holding ``"<old id> <new id>"``.
    """
    os.makedirs(output_image_dir, exist_ok=True)

    # Initialize combined data structure
    combined_data = {"images": [], "annotations": [], "categories": []}

    next_image_id = start_image_id
    next_annotation_id = 0
    skip_file_list = []
    image_id_mapping = {}

    # Find all matching JSON files
    json_files = sorted(glob.glob(input_pattern))
    print(f"Found {len(json_files)} JSON files to combine")

    # Process each file
    for i, json_file in enumerate(json_files):
        print(f"Processing file {i + 1}/{len(json_files)}: {json_file}")

        with open(json_file, "r") as f:
            data = json.load(f)

        # Take the categories from the first file that actually has some, so
        # they are not lost when the very first file carries none.
        if not combined_data["categories"] and data.get("categories"):
            combined_data["categories"] = data["categories"]

        images = data.get("images", [])
        annotations = data.get("annotations", [])

        # Files without annotations are skipped; remember their first image.
        if not annotations:
            if images:
                skip_file_list.append(images[0]["id"])
            continue

        # Process images
        for image in images:
            if image["id"] in image_id_mapping:
                # Original id already seen: keep the first occurrence only.
                skip_file_list.append(image["id"])
                continue
            image_id_mapping[image["id"]] = next_image_id
            image["id"] = next_image_id
            next_image_id += 1

            image["file_name"] = str(image["id"]).zfill(6) + ".png"
            # The source image is looked up by the JSON file's base name.
            org_file_name = os.path.basename(json_file).replace(".json", ".png")
            src_path = f"{data_path}/{split_type}/{org_file_name}"
            if image["file_name"] != org_file_name and os.path.exists(src_path):
                shutil.copy(src_path, f"{output_image_dir}/{image['file_name']}")
            combined_data["images"].append(image)

        # Process annotations
        for annotation in annotations:
            annotation["id"] = next_annotation_id
            next_annotation_id += 1
            annotation["image_id"] = image_id_mapping[annotation["image_id"]]
            combined_data["annotations"].append(annotation)

    # Write combined data to output file
    output_path = Path(output_file)
    output_path.parent.mkdir(exist_ok=True, parents=True)
    base_name = output_path.name.split(".")[0]

    with open(output_file, "w") as f:
        json.dump(combined_data, f, indent=2)

    with open(output_path.parent / f"{base_name}_image_id_mapping.json", "w") as f:
        json.dump(image_id_mapping, f, indent=2)

    if skip_file_list:
        with open(output_path.parent / f"{base_name}_skipped.txt", "w") as f:
            f.write("\n".join([str(x) for x in skip_file_list]))

    print(f"Combined data written to {output_file}")
    print(f"Total images: {len(combined_data['images'])}")
    print(f"Total annotations: {len(combined_data['annotations'])}")
    print(f"Total categories: {len(combined_data['categories'])}")
    print(f"Skipped images: {len(skip_file_list)}")

    # One "<old id> <new id>" string per image, each wrapped in its own list.
    image_id_mapping_list = [[f"{k} {v}"] for k, v in image_id_mapping.items()]

    return combined_data, image_id_mapping_list
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Combine multiple COCO-style JSON annotation files")
    # Both options are root directories: the per-split paths are derived
    # from them below, so the help texts describe directories.
    parser.add_argument(
        "--input",
        required=True,
        help="Input root directory containing '<split>_jsons/*.json' and '<split>/*.png' for train, val and test",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output root directory; receives 'annotations/<split>.json' and a '<split>/' image directory",
    )

    args = parser.parse_args()

    splits = ["train", "val", "test"]
    start_image_id = 0
    for i, split in enumerate(splits):
        if i > 0:
            # Continue numbering after the PNG count of the previous split so
            # image ids stay unique across train, val and test.
            start_image_id += len(list(Path(f"{args.input}/{splits[i - 1]}").glob("*.png")))

        combine_json_files(
            f"{args.input}/{split}_jsons/*.json",
            args.input,
            split,
            f"{args.output}/annotations/{split}.json",
            output_image_dir=f"{args.output}/{split}",
            start_image_id=start_image_id,
        )
|
data_preprocess/raster2graph/combine_mapping_ids.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def generate_combined_mapping(file_mapping_path, image_id_mapping_path, output_path):
    """Chain two id mappings into one ``<original filename> <destination id>`` file.

    Problems while reading or writing are reported with ``print`` and end the
    function early; nothing is raised to the caller and nothing is returned.

    Args:
        file_mapping_path (str): Text file whose two-field lines map an
            original filename to an intermediate integer id (leading zeros
            allowed).  Lines with any other field count are ignored.
        image_id_mapping_path (str): JSON object mapping intermediate ids to
            destination ids.
        output_path (str): Where the combined mapping is written, one
            ``"<filename> <destination>"`` pair per line.
    """
    # Step 1: original filename -> intermediate id
    name_to_mid = {}
    try:
        with open(file_mapping_path, "r") as handle:
            for raw_line in handle:
                fields = raw_line.strip().split()
                if len(fields) != 2:
                    continue
                name, mid_text = fields
                # int() also strips the zero padding of the 6-digit ids.
                name_to_mid[name] = int(mid_text)
    except FileNotFoundError:
        print(f"Error: The file '{file_mapping_path}' was not found.")
        return
    except Exception as e:
        print(f"Error reading '{file_mapping_path}': {e}")
        return

    # Step 2: intermediate id -> destination id (JSON keys arrive as strings)
    try:
        with open(image_id_mapping_path, "r") as handle:
            mid_to_dst = {int(key): value for key, value in json.load(handle).items()}
    except FileNotFoundError:
        print(f"Error: The file '{image_id_mapping_path}' was not found.")
        return
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from '{image_id_mapping_path}'. Please ensure it's valid JSON.")
        return
    except Exception as e:
        print(f"Error reading '{image_id_mapping_path}': {e}")
        return

    # Step 3: join both tables, warning about ids missing from the second one
    lines = []
    for name, mid in name_to_mid.items():
        if mid not in mid_to_dst:
            print(f"Warning: Intermediate ID '{mid}' for '{name}' not found in image ID mapping.")
            continue
        lines.append(f"{name} {mid_to_dst[mid]}")

    try:
        with open(output_path, "w") as handle:
            for line in lines:
                handle.write(line + "\n")
        print(f"\nSuccessfully generated combined mapping to '{output_path}'.")
        print(f"Total original filenames processed: {len(name_to_mid)}")
        print(f"Total combined mappings written: {len(lines)}")
    except Exception as e:
        print(f"Error writing to output file '{output_path}': {e}")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# Define file paths
|
| 78 |
+
# Input mappings and the destination of the combined mapping.
file_mapping_path = "data/R2G_hr_dataset_processed/test_file_mapping.txt"
image_id_mapping_path = "data/R2G_hr_dataset_processed_v1/annotations/test_image_id_mapping.json"
output_mapping_path = "data/R2G_hr_dataset_processed_v1/annotations/test_combined_mapping.txt"

# Build the combined mapping file.
generate_combined_mapping(file_mapping_path, image_id_mapping_path, output_mapping_path)

# Echo the generated file so the result can be checked by eye.
print("\n--- Content of combined_mapping.txt ---")
try:
    f = open(output_mapping_path, "r")
except FileNotFoundError:
    print("Output file was not created.")
else:
    with f:
        print(f.read())
|
| 92 |
+
|
| 93 |
+
# Clean up dummy files (optional)
|
| 94 |
+
# os.remove(file_mapping_path)
|
| 95 |
+
# os.remove(image_id_mapping_path)
|
data_preprocess/raster2graph/convert_to_coco.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
print(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 6 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import shutil
|
| 11 |
+
from multiprocessing import Pool
|
| 12 |
+
|
| 13 |
+
import cv2
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
from datasets.dataset import MyDataset
|
| 18 |
+
from matplotlib.patches import Patch
|
| 19 |
+
from shapely.geometry import Polygon
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
from util.data_utils import edge_inside
|
| 22 |
+
from util.graph_utils import get_cycle_basis_and_semantic, tensors_to_graphs_batch
|
| 23 |
+
|
| 24 |
+
# Three-value normalisation statistics.
# NOTE(review): presumably per-channel image mean/std -- confirm against usage.
mean = [0.920, 0.913, 0.891]
std = [0.214, 0.216, 0.228]

# Room class id -> class name; the ids are the positions in this tuple.
ID2CLASS = dict(
    enumerate(
        (
            "unknown",
            "living_room",
            "kitchen",
            "bedroom",
            "bathroom",
            "restroom",
            "balcony",
            "closet",
            "corridor",
            "washing_room",
            "PS",
            "outside",
            # id 12 ("wall") is deliberately left out
        )
    )
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def plot_room_map(preds, room_map, room_id=0, im_size=256, plot_text=True):
    """Draw one room polygon, its corners and optionally its label on *room_map*.

    Args:
        preds: Array of polygon corners indexed as ``preds[:, 0]`` (x) and
            ``preds[:, 1]`` (y).
        room_map: Image that is drawn on in place.
        room_id: Value rendered as the room label.
        im_size: Unused by this function.
        plot_text: When true, draw the label on a white box at the centroid.

    Returns:
        The same *room_map* object that was passed in.
    """
    cx = int(np.mean(preds[:, 0]))
    cy = int(np.mean(preds[:, 1]))

    # Label metrics are needed to centre the text and size its backdrop.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.3
    thickness = 1
    label = str(room_id)
    (label_w, label_h), _baseline = cv2.getTextSize(label, font, font_scale, thickness)
    edge_color = (252, 252, 0)

    corner_count = len(preds)
    for idx, corner in enumerate(preds):
        # The modulo wraps the last corner back to the first, closing the outline.
        successor = preds[(idx + 1) % corner_count]
        here = (round(corner[0]), round(corner[1]))
        there = (round(successor[0]), round(successor[1]))
        cv2.line(room_map, here, there, edge_color, 2)
        cv2.circle(room_map, here, 2, (0, 0, 255), 2)

    if not plot_text:
        return room_map

    half_w, half_h = label_w // 2, label_h // 2
    # Filled white rectangle behind the label.
    cv2.rectangle(
        room_map,
        (cx - half_w - 2, cy - half_h - 2),
        (cx + half_w + 2, cy + half_h + 2),
        (255, 255, 255),
        -1,
    )
    cv2.putText(
        room_map,
        label,
        (cx - half_w, cy + half_h),
        font,
        font_scale,
        (0, 100, 0),
        thickness,
    )

    return room_map
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def plot_density_map(sample, image_size, room_polys, pred_room_label_per_scene, plot_text=True):
    """Blend the room polygons over the input image and return the overlay.

    Args:
        sample: Either a NumPy array used as is, or an object exposing
            ``.cpu().numpy()`` whose axes are reordered with ``[1, 2, 0]``.
        image_size: Side length of the square polygon canvas; the input is
            also multiplied by ``image_size - 1``.
        room_polys: Iterable of room polygons (corner lists).
        pred_room_label_per_scene: Labels paired with *room_polys*.
        plot_text: Forwarded to ``plot_room_map``.

    Returns:
        The result of ``cv2.addWeighted`` on the two ``uint8`` images.
    """
    if isinstance(sample, np.ndarray):
        density_map = sample
    else:
        density_map = np.transpose(sample.cpu().numpy(), [1, 2, 0])

    # Anything that is not already three-channel is repeated to three channels.
    if density_map.shape[2] != 3:
        density_map = np.repeat(density_map, 3, axis=2)
    density_map = density_map * (image_size - 1)

    canvas = np.zeros([image_size, image_size, 3])
    for poly, label in zip(room_polys, pred_room_label_per_scene):
        canvas = plot_room_map(np.array(poly), canvas, label, im_size=image_size, plot_text=plot_text)

    alpha = 0.4  # weight of the input image; the polygons get 1 - alpha
    return cv2.addWeighted(density_map.astype(np.uint8), alpha, canvas.astype(np.uint8), 1 - alpha, 0)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def is_clockwise(points):
    """Return True when the edge sum ``(x2 - x1) * (y2 + y1)`` is positive.

    Args:
        points: Non-empty list of 2D points; the polygon is closed implicitly
            by pairing the last point with the first.
    """
    assert len(points) > 0
    successors = points[1:] + [points[0]]
    edge_sum = sum(
        ((nxt[0] - cur[0]) * (nxt[1] + cur[1]) for cur, nxt in zip(points, successors)),
        0.0,
    )
    return edge_sum > 0.0
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def resort_corners(corners):
    """Rotate a polygon to a canonical start corner and make it clockwise.

    The corner with the smallest ``x**2 + y**2`` becomes the first corner, and
    the remaining corners are reversed when the ring is not clockwise
    according to the same signed edge sum used by ``is_clockwise``.

    Args:
        corners: NumPy array with one corner per row; columns 0 and 1 are used
            as x and y.

    Returns:
        A new array holding the re-ordered corners. The input is not modified.
        (The previous implementation computed this array but returned the
        unmodified input, which made the function a no-op.)
    """
    squared_dist = corners[:, 0] ** 2 + corners[:, 1] ** 2
    start_corner_idx = np.argmin(squared_dist)

    # Rotate so that the corner closest to the origin comes first.
    corners_sorted = np.concatenate([corners[start_corner_idx:], corners[:start_corner_idx]])

    # Signed edge sum over the closed ring (same test as is_clockwise, vectorised).
    xy = corners_sorted[:, :2]
    nxt = np.roll(xy, -1, axis=0)
    signed_sum = np.sum((nxt[:, 0] - xy[:, 0]) * (nxt[:, 1] + xy[:, 1]))

    # Keep the start corner fixed and reverse the rest to flip the orientation.
    if not signed_sum > 0.0:
        corners_sorted[1:] = np.flip(corners_sorted[1:], 0)

    return corners_sorted
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2):
    """Build a padded ``[x_min, y_min, width, height]`` box clipped to the image.

    Args:
        bb_x: x coordinates to enclose.
        bb_y: y coordinates to enclose.
        image_width: Image width; the right edge is clipped to ``image_width - 1``.
        image_height: Image height; the bottom edge is clipped to ``image_height - 1``.
        bound_pad: Padding added on every side before clipping.

    Returns:
        A list ``[x_min, y_min, width, height]`` of NumPy scalars.
    """
    xs = np.unique(bb_x)
    ys = np.unique(bb_y)

    # Pad outwards, then clamp to the valid pixel range.
    left = np.maximum(np.min(xs) - bound_pad, 0)
    top = np.maximum(np.min(ys) - bound_pad, 0)
    right = np.minimum(np.max(xs) + bound_pad, image_width - 1)
    bottom = np.minimum(np.max(ys) + bound_pad, image_height - 1)

    return [left, top, right - left, bottom - top]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def prepare_dict():
    """Return an empty COCO-style dict with the room categories filled in.

    One category is emitted per ``ID2CLASS`` entry, except the entry with key 0.
    """
    categories = [
        {"supercategory": "room", "id": class_id, "name": class_name}
        for class_id, class_name in ID2CLASS.items()
        if class_id != 0
    ]
    return {"images": [], "annotations": [], "categories": categories}
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_args_parser():
    """Build the command-line parser for the conversion script.

    Returns:
        An ``argparse.ArgumentParser`` with two required string options:
        ``--dataset_path`` and ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Path to the dataset directory",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        # Previously a copy of the --dataset_path help text.
        help="Path to the output directory where converted files are written",
    )
    # Add more arguments as needed
    return parser
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def visualize_room_polygons(room_polygons, room_classes, image_size=512, save_path="cubicasa_debug.png"):
    """Render room polygons with class colours and save the figure to disk.

    Args:
        room_polygons: Iterable of polygons; each is converted with ``np.array``
            and indexed as an (N, 2) array of x, y coordinates.
        room_classes: Class id for each polygon, in the same order.
        image_size: Side length in pixels of the square figure; also used to
            flip the y axis so the plot matches image coordinates.
        save_path: Destination of the saved image.
    """
    # Figure of image_size x image_size pixels at the given DPI.
    dpi = 100
    figsize = (image_size / dpi, image_size / dpi)  # pixels -> inches

    # Discrete colour positions: class id i maps to the i-th of 13 evenly spaced
    # samples of the colormap. Ids outside that range are clamped so that an
    # unexpected class id cannot abort the visualisation.
    num_color_slots = 13
    cmap = plt.get_cmap("gist_ncar", 256)  # plt.cm.get_cmap was removed in Matplotlib 3.9
    norm = np.linspace(0, 1, num_color_slots)

    def class_color(class_id):
        slot = min(max(int(class_id), 0), num_color_slots - 1)
        return cmap(norm[slot])

    fig = plt.figure(figsize=figsize, dpi=dpi)
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.set_xlim(0, image_size)
        ax.set_ylim(0, image_size)
        ax.set_aspect("equal")
        ax.axis("off")

        # Plot each room polygon and fill it with its class colour.
        for polygon, room_cls in zip(room_polygons, room_classes):
            polygon_array = np.array(polygon).copy()
            polygon_array[:, 1] = image_size - 1 - polygon_array[:, 1]  # flip y
            ax.fill(polygon_array[:, 0], polygon_array[:, 1], color=class_color(room_cls), alpha=0.4, zorder=1)
            # Polygon border
            ax.plot(polygon_array[:, 0], polygon_array[:, 1], "k-", linewidth=2, zorder=2)

            # Class id label at the mean of the vertices.
            ax.text(
                np.mean(polygon_array[:, 0]),
                np.mean(polygon_array[:, 1]),
                str(room_cls),
                fontsize=12,
                ha="center",
                va="center",
                bbox=dict(facecolor="white", alpha=0.7),
                zorder=3,
            )

        # One legend entry per known class; the name is looked up by key rather
        # than by list position so non-contiguous ids are labelled correctly.
        legend_elements = [
            Patch(
                facecolor=class_color(cls),
                edgecolor="black",
                label=f"{int(cls)}_{ID2CLASS[cls]}",
                alpha=0.6,
            )
            for cls in sorted(ID2CLASS)
        ]
        ax.legend(
            handles=legend_elements,
            loc="best",
            title="Classes",
            fontsize=10,
            markerscale=1,
            title_fontsize=12,
            framealpha=0.5,
        )

        fig.tight_layout(pad=0)
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure, even if plotting or saving failed.
        plt.close(fig)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def process_floorplan(image_set, split, source_data_path, save_dir, save_aux_dir, vis_fp=False):
    """Convert one dataset sample into a per-image COCO-style JSON file.

    Args:
        image_set: ``(img, target)`` pair; ``target`` is a dict that must hold
            the keys read below ("graph", "points", "unnormalized_points",
            "edges", "size", "file_name", "image_id" and optionally the four
            "semantic_*" entries). The "graph" entry is deleted from it.
        split: Split name; used for output sub-directories and file names.
        source_data_path: Root that contains ``<split>/<file_name>.png``.
        save_dir: Output root for images, JSON files and mapping logs.
        save_aux_dir: Directory for the optional debug visualisations.
        vis_fp: When true, also write the polygon plot and the blended image.

    Side effects:
        Writes ``<save_dir>/<split>_jsons/<id>.json``, copies the source image
        to ``<save_dir>/<split>/<id>.png``, appends a line to a per-process
        mapping file under ``<save_dir>/<split>_logs`` and prints progress.
    """
    img, target = image_set
    img = img * torch.tensor(std)[:, None, None] + torch.tensor(mean)[:, None, None]  # unnormalize
    # NOTE(review): presumably the inverse of graph_to_tensor used by the dataset - confirm.
    graph = tensors_to_graphs_batch([target["graph"]])
    del target["graph"]

    # Build one dict per point and the list of undirected edges between them.
    tgt_this_preds = []
    tgt_this_edges = []
    for _ in range(len(target["points"])):
        tgt_p_d = {}
        tgt_p_d["scores"] = torch.tensor(1.0000, device="cpu")
        tgt_p_d["points"] = target["unnormalized_points"][_]
        tgt_p_d["edges"] = target["edges"][_]
        tgt_p_d["size"] = target["size"]
        if "semantic_left_up" in target:
            tgt_p_d["semantic_left_up"] = target["semantic_left_up"][_]
            tgt_p_d["semantic_right_up"] = target["semantic_right_up"][_]
            tgt_p_d["semantic_right_down"] = target["semantic_right_down"][_]
            tgt_p_d["semantic_left_down"] = target["semantic_left_down"][_]
        tgt_this_preds.append(tgt_p_d)
        # Each graph node stores four neighbour slots; (-1, -1) marks an empty slot.
        for __ in range(4):
            adj = graph[0][tuple(tgt_p_d["points"].tolist())][__]
            if adj != (-1, -1):
                tgt_p_d1 = tgt_p_d
                tgt_p_d2 = {}
                # 99999 is the "not found" sentinel for the neighbour's index.
                indx = 99999
                # Match the neighbour to a point within an L1 distance of 2.
                for ___, up in enumerate(target["unnormalized_points"].tolist()):
                    if abs(up[0] - adj[0]) + abs(up[1] - adj[1]) <= 2:
                        indx = ___
                        break
                # assert indx != 99999
                if indx == 99999:  # No match found
                    # Log a warning or skip this iteration
                    print(f"Warning: No match found for adj {adj}")
                    continue  # Skip to the next iteration
                # tgt_p_d2['scores'] = torch.tensor(1.0000, device='cuda:0')
                tgt_p_d2["points"] = target["unnormalized_points"][indx]
                tgt_p_d2["edges"] = target["edges"][indx]
                tgt_p_d2["size"] = target["size"]
                if "semantic_left_up" in target:
                    tgt_p_d2["semantic_left_up"] = target["semantic_left_up"][indx]
                    tgt_p_d2["semantic_right_up"] = target["semantic_right_up"][indx]
                    tgt_p_d2["semantic_right_down"] = target["semantic_right_down"][indx]
                    tgt_p_d2["semantic_left_down"] = target["semantic_left_down"][indx]
                tgt_e_l = (tgt_p_d1, tgt_p_d2)
                # Only add the edge when its reversed form is not already recorded.
                if not edge_inside((tgt_p_d2, tgt_p_d1), tgt_this_edges):
                    tgt_this_edges.append(tgt_e_l)
    tgt = [(tgt_this_preds, [], tgt_this_edges)]
    # The second and third results are iterated below as polygons and their class ids.
    target_d_rev, target_simple_cycles, target_results = get_cycle_basis_and_semantic((2, 999999, tgt))

    # convert to coco format
    polys_list = []
    polys_semantic_list = []
    output_json = []

    # NOTE(review): MyDataset stores size as [img.size[1], img.size[0]], i.e.
    # (height, width) for a PIL image, so these two names may be swapped;
    # harmless for square images - confirm.
    image_width, image_height = target["size"][0].item(), target["size"][1].item()
    filename = target["file_name"].split(".")[0]
    img_id = int(target["image_id"])

    img_dict = {}
    img_dict["file_name"] = str(img_id).zfill(6) + ".png"
    img_dict["id"] = img_id
    img_dict["width"] = image_width
    img_dict["height"] = image_height
    save_dict = prepare_dict()

    os.makedirs(os.path.join(save_dir, split), exist_ok=True)
    os.makedirs(f"{save_dir}/{split}_jsons/", exist_ok=True)
    json_path = f"{save_dir}/{split}_jsons/{str(img_id).zfill(6)}.json"

    # instance_id comes from enumerate, so skipped polygons leave gaps in the ids.
    for instance_id, (poly, poly_cls) in enumerate(zip(target_simple_cycles, target_results)):
        t = [(int(pt[0]), int(pt[1])) for pt in poly]
        class_id = int(poly_cls)

        # Recorded before any filtering: the visualisation lists include polygons
        # that are later skipped for the JSON output.
        polys_list.append(t)
        polys_semantic_list.append(class_id)

        poly_shapely = Polygon(t)
        area = poly_shapely.area
        coco_seg_poly = []
        polygon = np.array(t)
        poly_sorted = resort_corners(polygon)

        # Flatten the corners into [x0, y0, x1, y1, ...].
        for p in poly_sorted:
            coco_seg_poly += list(p)

        # Drop polygons with an area below 100.
        if area < 100:
            continue

        if class_id not in ID2CLASS:
            print(f"Warning: Class ID {class_id} not found in ID2CLASS mapping. Skipping instance.")
            continue

        # Slightly wider bounding box
        rectangle_shapely = poly_shapely.envelope
        bb_x, bb_y = rectangle_shapely.exterior.xy
        coco_bb = create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2)

        output_json.append(
            {
                "image_id": img_id,
                "segmentation": [coco_seg_poly],
                "category_id": class_id,
                "id": instance_id,
                "area": area,
                "bbox": coco_bb,
                "iscrowd": 0,
            }
        )

    if vis_fp:
        visualize_room_polygons(
            polys_list,
            polys_semantic_list,
            image_size=image_width,
            save_path=os.path.join(save_aux_dir, str(img_id).zfill(6) + ".png"),
        )
        room_map = plot_density_map(
            img,
            image_width,
            polys_list,
            polys_semantic_list,
            plot_text=False,
        )
        cv2.imwrite(os.path.join(save_aux_dir, str(img_id).zfill(6) + "_density_map.png"), room_map)

    print(f"Processed image {img_id} with {len(output_json)} instances.")
    # print(f"Class: {target_results}")
    # min_class_id = min(target_results)
    # max_class_id = max(target_results)
    # if max_class_id == 12:
    #     breakpoint()
    # print(f"Min class ID: {min_class_id}, Max class ID: {max_class_id}")
    save_dict["images"].append(img_dict)
    save_dict["annotations"] += output_json
    with open(json_path, "w") as json_file:
        # Convert all numpy and torch types to native Python types for JSON serialization
        def convert(o):
            if isinstance(o, (np.integer, np.int32, np.int64)):
                return int(o)
            if isinstance(o, (np.floating, np.float32, np.float64)):
                return float(o)
            if isinstance(o, (np.ndarray,)):
                return o.tolist()
            if isinstance(o, torch.Tensor):
                return o.item() if o.numel() == 1 else o.tolist()
            # Anything else is stringified rather than raising.
            return str(o)

        json.dump(save_dict, json_file, default=convert)

    # rename image file
    shutil.copy(
        os.path.join(source_data_path, split, filename + ".png"),
        os.path.join(save_dir, split, str(img_id).zfill(6) + ".png"),
    )

    # Write mapping from source file name to target file name (safe for parallel)
    mapping_line = f"{filename} {str(img_id).zfill(6)}\n"
    # Each process writes to its own temp file
    pid = os.getpid()
    os.makedirs(os.path.join(save_dir, f"{split}_logs"), exist_ok=True)
    mapping_file = os.path.join(save_dir, f"{split}_logs", f"{split}_file_mapping_{pid}.txt")
    with open(mapping_file, "a") as f:
        f.write(mapping_line)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# Script entry point: convert every split of the dataset with a process pool.
if __name__ == "__main__":
    args = get_args_parser().parse_args()
    torch.set_printoptions(threshold=np.inf, linewidth=999999)
    np.set_printoptions(threshold=np.inf, linewidth=999999)
    gc.collect()
    torch.cuda.empty_cache()

    def wrapper(scene_id):
        # Reads `dataset`, `split`, `save_aux_dir` and `args` from module globals.
        # NOTE(review): defined inside the __main__ guard, so this presumably
        # relies on the "fork" start method for the workers - confirm.
        try:
            image_set = dataset[scene_id]
        except Exception as e:
            # A sample that fails to load is reported and skipped.
            print(f"Error processing scene {scene_id}: {e}. Skipping...")
            return
        # Debug visualisations are only written for the first 100 indices.
        process_floorplan(image_set, split, args.dataset_path, args.output_dir, save_aux_dir, vis_fp=scene_id < 100)

    def worker_init(dataset_obj):
        # Store dataset as global to avoid pickling issues
        global dataset
        dataset = dataset_obj

    splits = ["train", "val", "test"]
    for split in splits:
        dataset = MyDataset(
            args.dataset_path + f"/{split}",
            args.dataset_path + "/annot_json" + f"/instances_{split}.json",
            extract_roi=False,
        )

        save_aux_dir = os.path.join(args.output_dir, f"{split}_aux")
        os.makedirs(save_aux_dir, exist_ok=True)

        # for i, image_set in enumerate(tqdm(dataset)):
        #     save_aux_dir = os.path.join(args.output_dir, f"{split}_aux")
        #     os.makedirs(save_aux_dir, exist_ok=True)
        #     process_floorplan(image_set, split, args.dataset_path, args.output_dir, save_aux_dir, vis_fp=i < 100)

        num_processes = 16
        with Pool(num_processes, initializer=worker_init, initargs=(dataset,)) as p:
            indices = range(len(dataset))
            # list(...) drains the lazy imap iterator so every index is processed.
            list(tqdm(p.imap(wrapper, indices), total=len(dataset)))
|
data_preprocess/raster2graph/dataset.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.multiprocessing
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision.transforms.functional as F
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
from util.data_utils import l1_dist
|
| 14 |
+
from util.graph_utils import graph_to_tensor
|
| 15 |
+
from util.image_id_dict import d
|
| 16 |
+
from util.mean_std import mean, std
|
| 17 |
+
from util.semantics_dict import semantics_dict
|
| 18 |
+
|
| 19 |
+
torch.multiprocessing.set_sharing_strategy("file_system")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MyDataset(Dataset):
    """Floor-plan dataset yielding a normalised image tensor and a target dict.

    Annotations come from a JSON file with "images" and "annotations" lists,
    plus one pickled ``.npy`` file per image under ``<parent of img_path>/annot_npy``
    that holds the node graph and a "quatree" entry.

    NOTE(review): all coordinates are rescaled with the width ratio only while
    the image is resized to a square; this assumes square source images - confirm.
    """

    # Image ids that are always dropped from the index.
    _EXCLUDED_IDS = (
        "0c-10-c468a57377ff8ef63d3b26a6d1fa-0002",
        "0c-10-8486f08035ba152d5244ac54099c-0001",
    )

    def __init__(self, img_path, annot_path, extract_roi, image_size=512):
        """Index the annotation file.

        Args:
            img_path: Directory with the images; its last path component is
                stored as ``mode``.
            annot_path: JSON annotation file.
            extract_roi: Accepted for interface compatibility; not used here.
            image_size: Side length the images are resized to.
        """
        self.img_path = img_path
        self.quadtree_path = "/".join(img_path.split("/")[:-1]) + "/annot_npy"
        self.mode = img_path.split("/")[-1]
        self.image_size = image_size

        # load annotation
        with open(annot_path, "r") as f:
            dataset = json.load(f)

        # images, keyed by id
        self.imgs = {img["id"]: img for img in dataset["images"]}
        # annotations grouped by the image they belong to
        self.imgToAnns = defaultdict(list)
        for ann in dataset["annotations"]:
            self.imgToAnns[ann["image_id"]].append(ann)
        self.ids = [img_id for img_id in sorted(self.imgs) if img_id not in self._EXCLUDED_IDS]

    def __getitem__(self, index):
        """Return ``(img, target)`` for the sample at ``index``.

        ``target`` holds: "edges", "file_name", "image_id", "size", the four
        "semantic_*" tensors, "unnormalized_points", "points" (divided by the
        image width/height), "layer_indices" and "graph".
        """
        img_id = self.ids[index]
        img_file_name = self.imgs[img_id]["file_name"].replace(".jpg", ".png")
        img = Image.open(os.path.join(self.img_path, img_file_name)).convert("RGB")
        image_scale = self.image_size / img.size[0]
        if img.size[0] != self.image_size or img.size[1] != self.image_size:
            img = img.resize((self.image_size, self.image_size), Image.BILINEAR)

        def rescale(pt):
            # Truncating conversion into the resized image's pixel grid.
            return (int(pt[0] * image_scale), int(pt[1] * image_scale))

        # get structure annotations (copied so the cached originals stay untouched)
        annotations = []
        for ann in self.imgToAnns[img_id]:
            new_ann = copy.deepcopy(ann)
            new_ann["point"] = list(rescale(ann["point"]))
            annotations.append(new_ann)

        # Load the per-image annotation file once (it was previously unpickled
        # twice). NOTE: allow_pickle=True executes pickle data; only use this on
        # trusted files.
        npy_path = os.path.join(self.quadtree_path, img_file_name[:-4] + ".npy")
        annot = np.load(npy_path, allow_pickle=True).item()

        quadtree = {
            layer: [rescale(pos) for pos in positions]
            for layer, positions in annot["quatree"][0].items()
        }

        # Every entry except "quatree" maps a node to its list of neighbours;
        # (-1, -1) marks an empty neighbour slot and is kept as is.
        new_graph = {}
        for node, neighbours in annot.items():
            if isinstance(node, str) and node == "quatree":
                continue
            new_graph[rescale(node)] = [
                (-1, -1) if adj == (-1, -1) else rescale(adj) for adj in neighbours
            ]

        # Order the annotations layer by layer: for each layer point take the
        # first annotation within an L1 distance of 2.
        target_layers = []
        for layer_points in quadtree.values():
            for layer_point in layer_points:
                for candidate in annotations:
                    if l1_dist(candidate["point"], list(layer_point)) <= 2:
                        target_layers.append(candidate)
                        break

        # Running offset of each layer, based on the number of layer points.
        layer_indices = []
        count = 0
        for k, v in quadtree.items():
            layer_indices.append(0 if k == 0 else count)
            count += len(v)

        image_id = torch.tensor([d[img_id]])

        points = torch.as_tensor([obj["point"] for obj in target_layers], dtype=torch.int64).reshape(-1, 2)
        edges = torch.tensor([obj["edge_code"] for obj in target_layers], dtype=torch.int64)

        # get semantic annotations: one class id per corner quadrant
        def semantic_tensor(slot):
            return torch.tensor(
                [semantics_dict[obj["semantic"][slot]] for obj in target_layers], dtype=torch.int64
            )

        # annotations
        target = {}
        target["edges"] = edges
        target["file_name"] = img_file_name
        target["image_id"] = image_id
        target["size"] = torch.as_tensor([img.size[1], img.size[0]])

        target["semantic_left_up"] = semantic_tensor(0)
        target["semantic_right_up"] = semantic_tensor(1)
        target["semantic_right_down"] = semantic_tensor(2)
        target["semantic_left_down"] = semantic_tensor(3)

        # get image
        img = F.to_tensor(img)
        img = F.normalize(img, mean=mean, std=std)
        target["unnormalized_points"] = points
        # normalize by (width, height) of the tensor image
        target["points"] = points / torch.tensor([img.shape[2], img.shape[1]], dtype=torch.float32)
        target["layer_indices"] = torch.tensor(layer_indices)

        target["graph"] = graph_to_tensor(new_graph)

        return img, target

    def __len__(self):
        """Number of indexed images."""
        return len(self.ids)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class MyDataset2(Dataset):
    """Variant dataset reading per-split ``.npy`` annotations and edge-code JSON files.

    Only images whose id has a matching ``<id>.npy`` file in
    ``<parent of img_path>/annotations_npy/<split>`` are indexed.
    """

    def __init__(self, img_path, annot_path, extract_roi, disable_sem_info=False):
        """Index the annotation file, restricted to ids that have an ``.npy`` file.

        Args:
            img_path: Directory with the images; its last component is the split name.
            annot_path: JSON annotation file with "images" and "annotations" lists.
            extract_roi: Accepted for interface compatibility; not used here.
            disable_sem_info: When true, no semantic tensors are produced.
        """
        parent_dir = "/".join(img_path.split("/")[:-1])
        split_name = img_path.split("/")[-1]

        self.disable_sem_info = disable_sem_info
        self.img_path = img_path
        self.quadtree_path = parent_dir + "/annotations_npy/" + split_name
        self.edgecode_path = parent_dir + "/annotations_edge/" + split_name
        self.mode = split_name

        available_ids = {int(name.replace(".npy", "")) for name in os.listdir(self.quadtree_path)}

        # load annotation
        with open(annot_path, "r") as f:
            dataset = json.load(f)

        # images
        self.imgs = {}
        for img in dataset["images"]:
            if img["id"] in available_ids:
                self.imgs[img["id"]] = img

        self.imgToAnns = defaultdict(list)
        for ann in dataset["annotations"]:
            if ann["image_id"] in available_ids:
                self.imgToAnns[ann["image_id"]].append(ann)

        self.ids = list(sorted(self.imgs.keys()))

    def __getitem__(self, index):
        """Return ``(img, target)`` for the sample at ``index``."""
        img_id = self.ids[index]
        img_file_name = self.imgs[int(img_id)]["file_name"]
        img = Image.open(os.path.join(self.img_path, img_file_name)).convert("RGB")
        stem = img_file_name[:-4]

        # get structure annotations
        data = np.load(os.path.join(self.quadtree_path, stem + ".npy"), allow_pickle=True).item()
        orig_quadtree = data["quadtree"]
        orig_graph = data["graph"]
        image_points = data["points"]

        # Annotations here only carry the integer point position.
        annotations = [{"point": [int(pt[0]), int(pt[1])]} for pt in image_points]

        quadtree = {}
        for layer, positions in orig_quadtree.items():
            quadtree[layer] = [(int(pos[0]), int(pos[1])) for pos in positions]

        new_graph = {}
        for node, neighbours in orig_graph.items():
            converted = []
            for adj in neighbours:
                if adj == (-1, -1):
                    converted.append((-1, -1))
                else:
                    converted.append((int(adj[0]), int(adj[1])))
            new_graph[(int(node[0]), int(node[1]))] = converted

        # For every layer point keep the first annotation within L1 distance 2.
        target_layers = []
        for layer_points in quadtree.values():
            for layer_point in layer_points:
                for candidate in annotations:
                    if l1_dist(candidate["point"], list(layer_point)) <= 2:
                        target_layers.append(candidate)
                        break

        # Running offset of each layer, based on the number of layer points.
        layer_indices = []
        running_total = 0
        for layer, positions in quadtree.items():
            if layer == 0:
                layer_indices.append(0)
            else:
                layer_indices.append(running_total)
            running_total += len(positions)

        image_id = torch.tensor([int(img_id)])

        point_list = [obj["point"] for obj in target_layers]

        # Edge codes are stored in JSON keyed by a stringified "(x, y)" tuple.
        with open(os.path.join(self.edgecode_path, stem + ".json"), "r") as f:
            raw_edge2code = json.load(f)
        edge2code = {}
        for key, value in raw_edge2code.items():
            parsed_key = tuple(int(float(part)) for part in key.strip("()").split(", "))
            edge2code[parsed_key] = value

        edge_list = [edge2code[(int(pt[0]), int(pt[1]))] for pt in point_list]
        points = torch.as_tensor(point_list, dtype=torch.int64).reshape(-1, 2)
        edges = torch.tensor(edge_list, dtype=torch.int64)

        # annotations
        target = {}
        target["edges"] = edges
        target["image_id"] = image_id
        target["file_name"] = img_file_name
        target["size"] = torch.as_tensor([img.size[1], img.size[0]])

        # get semantic annotations
        # NOTE(review): the annotations built above have no "semantic" key, so
        # this branch presumably only works with richer annotations - confirm.
        if not self.disable_sem_info:
            semantic_keys = (
                "semantic_left_up",
                "semantic_right_up",
                "semantic_right_down",
                "semantic_left_down",
            )
            per_corner = [
                [semantics_dict[obj["semantic"][slot]] for obj in target_layers]
                for slot in range(4)
            ]
            for key, values in zip(semantic_keys, per_corner):
                target[key] = torch.tensor(values, dtype=torch.int64)

        # get image
        img = F.to_tensor(img)
        img = F.normalize(img, mean=mean, std=std)
        target["unnormalized_points"] = points
        # normalize
        target["points"] = points / torch.tensor([img.shape[2], img.shape[1]], dtype=torch.float32)
        target["layer_indices"] = torch.tensor(layer_indices)

        # Force exactly four neighbour slots per node: pad with (-1, -1) or truncate.
        for node in new_graph:
            neighbours = new_graph[node]
            missing = 4 - len(neighbours)
            if missing > 0:
                neighbours.extend([(-1, -1)] * missing)
            elif missing < 0:
                new_graph[node] = neighbours[:4]
        target["graph"] = graph_to_tensor(new_graph)

        return img, target

    def __len__(self):
        """Number of indexed images."""
        return len(self.ids)
|
data_preprocess/raster2graph/image_process.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
# Rescale every annotated floor plan so that its vector boundary fits a
# SIZE x SIZE canvas with a MARGIN-pixel border, and save it as PNG under ./<split>.
parser = argparse.ArgumentParser("Preprocess LIFULL HOMES DATA (HIGH RESOLUTION) Dataset")
parser.add_argument("--data_root", type=str, default=r"R2G_hr_dataset/", help="path to the root folder of the dataset")
args = parser.parse_args()

SIZE = 512  # output canvas side length in pixels
MARGIN = 64  # border kept free around the floor plan
np.set_printoptions(threshold=np.inf, linewidth=999999)

# original_images_path = r'E:/LIFULL HOMES DATA (HIGH RESOLUTION)/photo-rent-madori-full-00'
original_images_path = args.data_root


def _load_file_names(split):
    """Return the image file names listed in the annotation JSON of ``split``.

    The file is looked up under ``annot_json/`` first (where train and val
    live); the dataset root is used as a fallback because the test annotations
    were previously read from there.
    """
    annot_path = f"{args.data_root}/annot_json/instances_{split}.json"
    root_path = f"{args.data_root}/instances_{split}.json"
    if not os.path.exists(annot_path) and os.path.exists(root_path):
        annot_path = root_path
    with open(annot_path, mode="r") as f_annot:
        return [entry["file_name"] for entry in json.load(f_annot)["images"]]


jpgs = {mode: _load_file_names(mode) for mode in ("train", "val", "test")}

for mode in ["train", "val", "test"]:
    output_dir = "./" + mode
    os.makedirs(output_dir, exist_ok=True)
    for file_name in tqdm(jpgs[mode]):
        fn = file_name.replace(".jpg", "")
        annot_npy_path = os.path.join(f"{args.data_root}/annot_npy", fn + ".npy")
        boundary_path = os.path.join(f"{args.data_root}/original_vector_boundary", fn + ".npy")
        # Only images that have both annotation files are converted.
        if not (os.path.exists(annot_npy_path) and os.path.exists(boundary_path)):
            continue

        # NOTE: allow_pickle=True executes pickle data; only use on trusted files.
        boundary = np.load(boundary_path, allow_pickle=True).item()
        x_min = boundary["x_min"]
        x_max = boundary["x_max"]
        y_min = boundary["y_min"]
        y_max = boundary["y_max"]
        width = x_max - x_min
        mid_width = (x_max + x_min) / 2
        height = y_max - y_min
        mid_height = (y_max + y_min) / 2
        # The longer side of the boundary fills the canvas minus both margins.
        scale = (SIZE - 2 * MARGIN) / (width if width > height else height)

        # Dashes in the file name encode the sub-directory structure of the raw data.
        source_path = os.path.join(original_images_path, fn.replace("-", "/") + ".jpg")
        with Image.open(source_path) as img_original:
            original_width, original_height = img_original.size
            new_width = int(original_width * scale)
            new_height = int(original_height * scale)
            scaled_image = img_original.resize((new_width, new_height), Image.Resampling.LANCZOS)

        canvas = Image.new("RGB", (SIZE, SIZE), (255, 255, 255))
        # Place the image so the boundary centre lands on the canvas centre.
        x_topleft_offset = int(SIZE / 2 - mid_width * scale)
        y_topleft_offset = int(SIZE / 2 - mid_height * scale)
        canvas.paste(scaled_image, (x_topleft_offset, y_topleft_offset))

        canvas.save(os.path.join(output_dir, fn + ".png"))
|
data_preprocess/raster2graph/util/data_utils.py
ADDED
|
@@ -0,0 +1,966 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from util.edges_utils import get_edges_alldirections_rev
|
| 6 |
+
from util.math_utils import clip
|
| 7 |
+
from util.mean_std import mean, std
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def data_to_cuda(samples, targets):
    """Move a batch to the CUDA device.

    Args:
        samples: object exposing ``.to(device)``.
        targets: iterable of per-sample dicts. ``str`` values are passed
            through untouched; every other value is moved with ``.to(device)``.

    Returns:
        ``(samples_on_cuda, targets_on_cuda)``; the targets are returned as a
        new list of new dicts.
    """
    device = torch.device("cuda")
    moved_samples = samples.to(device)
    moved_targets = []
    for target in targets:
        moved = {}
        for key, value in target.items():
            # Strings have no device, so they are copied over as they are.
            moved[key] = value if isinstance(value, str) else value.to(device)
        moved_targets.append(moved)
    return moved_samples, moved_targets
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_random_layer_targets(targets, gt_layer):
    """Restrict every target to the points and edges of one chosen layer.

    Args:
        targets: list of per-sample dicts holding ``layer_indices`` (start
            offset of each layer), ``points``, ``edges`` and
            ``unnormalized_points``.
        gt_layer: per-sample index of the layer to keep.

    Returns:
        A list of deep copies of the targets in which ``points``, ``edges``
        and ``unnormalized_points`` are sliced to the chosen layer and
        ``layer_indices`` is removed. The inputs are not modified.
    """
    layer_targets = []
    for sample_idx, target in enumerate(targets):
        layer_target = copy.deepcopy(target)
        layer_starts = layer_target["layer_indices"]
        layer = gt_layer[sample_idx]

        start = layer_starts[layer].item()
        if layer == len(layer_starts) - 1:
            # The last layer extends to the end of the point list.
            stop = len(layer_target["points"])
        else:
            # Any other layer stops where the next one begins.
            stop = layer_starts[layer + 1].item()

        layer_points = layer_target["points"][start:stop, :]
        layer_edges = layer_target["edges"][start:stop]
        layer_raw_points = layer_target["unnormalized_points"][start:stop, :]
        layer_target["points"] = layer_points
        layer_target["edges"] = layer_edges
        layer_target["unnormalized_points"] = layer_raw_points
        del layer_target["layer_indices"]
        layer_targets.append(layer_target)
    return layer_targets
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def random_layers(targets):
    """Pick one random layer index per target.

    Returns:
        A list with, for each target, an integer drawn uniformly from
        ``0 .. len(target["layer_indices"]) - 1`` (both ends included).
    """
    chosen = []
    for idx in range(len(targets)):
        last_layer = len(targets[idx]["layer_indices"]) - 1
        chosen.append(random.randint(0, last_layer))
    return chosen
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_given_layers_random_region(targets, graphs):
    """Sample a random region of each sample's graph, grown from its first point.

    For every sample a node budget is drawn uniformly from
    ``0 .. number of distinct points``. Starting at the first entry of
    ``unnormalized_points`` the region is grown round by round: the unsampled
    neighbours of the current region are shuffled and each one is accepted
    with probability 0.5 until the budget is reached.

    Args:
        targets: list of per-sample dicts with ``unnormalized_points`` (and
            ``layer_indices``, only read when the acceptance probability is 1).
        graphs: list of per-sample adjacency dicts mapping a point tuple to
            its neighbours; ``(-1, -1)`` marks a missing neighbour.

    Returns:
        A list with one ``(sampled_points, sampled_edges)`` pair per sample.
        ``sampled_points`` maps every point tuple to 1 (sampled) or 0, and
        ``sampled_edges`` lists ``(already_sampled_neighbour, new_point)``
        pairs recorded when a point was accepted.
    """
    # Probability with which each frontier node is accepted in a round.
    sampled_prob = 0.5

    regions = []
    for batch_idx in range(len(targets)):
        target = targets[batch_idx]
        graph = graphs[batch_idx]
        start_point = tuple(target["unnormalized_points"][0].tolist())

        flags = {}
        for point_tensor in target["unnormalized_points"]:
            flags[tuple(point_tensor.tolist())] = 0
        edges = []

        budget = random.randint(0, len(flags))  # TODO: 1~len(sampled_points)

        # With probability 1 the growth is a plain BFS, so the budget has to
        # fall on a layer boundary (or on 0 / all points) to cover whole layers.
        if sampled_prob == 1:
            allowed = target["layer_indices"].tolist()
            allowed.append(len(flags))
            allowed.append(0)
            allowed.append(len(flags))
            budget = allowed[random.randint(0, len(allowed) - 1)]

        if budget == 0:
            regions.append((flags, edges))
            continue
        flags[start_point] = 1
        if budget == 1:
            regions.append((flags, edges))
            continue

        # Hard cap on the number of rounds so that an unlucky run cannot spin forever.
        max_rounds = max(1000, 10 * budget)
        rounds = 0
        while sum(flags.values()) < budget:
            rounds += 1
            if rounds > max_rounds:
                print("Reached maximum iterations, breaking to avoid infinite loop.")
                break

            taken = {point for point, flag in flags.items() if flag == 1}
            frontier = set()
            for point in taken:
                neighbours = {(int(adj[0]), int(adj[1])) for adj in graph[point]}
                frontier = frontier.union(neighbours)

            if (-1, -1) in frontier:
                frontier.remove((-1, -1))
            frontier = list(frontier.difference(taken))

            if not frontier:
                # The region cannot grow any further.
                print("No more adjacent points to sample, breaking the loop.")
                break

            # Shuffle so that no neighbour is favoured by the iteration order.
            random.shuffle(frontier)
            for candidate in frontier:
                taken = {point for point, flag in flags.items() if flag == 1}
                if sum(flags.values()) == budget:
                    break
                if random.random() < sampled_prob:
                    flags[candidate] = 1
                    # Record the edges that tie the new node to the region
                    # as it was before this node was accepted.
                    for neighbour in graph[candidate]:
                        if neighbour in taken:
                            edges.append((neighbour, candidate))
                else:
                    flags[candidate] = 0
        regions.append((flags, edges))
    return regions
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_random_region_targets(given_layers, graphs, targets):
    """Build the prediction targets that follow each sampled region.

    Three cases are handled per sample, depending on how many points of
    ``given_layers`` are flagged as sampled:

    * none sampled: the target is the first point of the sample;
    * some but not all sampled: the target is every unsampled point adjacent
      to a sampled one (falling back to the first point when there is none);
    * otherwise: the target is the end sign at (505, 505).

    Args:
        given_layers: list of ``(sampled_points, sampled_edges)`` pairs, where
            ``sampled_points`` maps a point tuple to 1 (sampled) or 0.
        graphs: list of adjacency dicts mapping a point tuple to its
            neighbours; ``(-1, -1)`` marks a missing neighbour.
        targets: list of per-sample dicts with ``edges``, ``image_id``,
            ``size``, ``unnormalized_points``, ``points`` and optionally the
            four ``semantic_*`` tensors.

    Returns:
        A list with one target dict per sample holding ``edges``, the
        ``semantic_*`` entries when the input has them, ``image_id``,
        ``size``, ``unnormalized_points``, ``points``, ``last_edges`` and
        ``this_edges``.
    """
    semantic_keys = ("semantic_left_up", "semantic_right_up", "semantic_right_down", "semantic_left_down")

    def first_point_target(target):
        # Target made of the first point only, with zeroed last/this edge codes.
        edge_ref = target["edges"]
        result = {"edges": edge_ref[:1]}
        if "semantic_left_up" in target:
            for key in semantic_keys:
                result[key] = target[key][:1]
        result["image_id"] = target["image_id"]
        result["size"] = target["size"]
        result["unnormalized_points"] = target["unnormalized_points"][:1]
        result["points"] = target["points"][:1]
        result["last_edges"] = torch.zeros((1,), dtype=edge_ref.dtype, device=edge_ref.device)
        result["this_edges"] = torch.zeros((1,), dtype=edge_ref.dtype, device=edge_ref.device)
        return result

    def filled(shape, value, like):
        # Tensor of the given shape filled with value, typed and placed like `like`.
        return value * torch.ones(shape, dtype=like.dtype, device=like.device)

    def encode(graph, points, is_set):
        # One edge code per point: a '1'/'0' flag per neighbour slot, converted
        # by get_edges_alldirections_rev.
        codes = []
        for point in points:
            bits = "".join("1" if is_set(adj) else "0" for adj in graph[point])
            codes.append(get_edges_alldirections_rev(bits))
        return codes

    region_targets = []
    for batch_idx in range(len(targets)):
        target = targets[batch_idx]
        graph = graphs[batch_idx]
        flags, _sampled_edges = given_layers[batch_idx]
        given_count = sum(flags.values())

        if given_count == 0:
            region_targets.append(first_point_target(target))
        elif 1 <= given_count <= len(flags) - 1:
            given = {point for point, flag in flags.items() if flag == 1}
            # Frontier: unsampled points with at least one sampled neighbour.
            frontier = [
                point
                for point, flag in flags.items()
                if flag == 0 and any(adj in given for adj in graph[point])
            ]

            if len(frontier) == 0:
                region_targets.append(first_point_target(target))
                continue

            # Indices of the original points lying within 2 px of a frontier point.
            # NOTE(review): the match is not stopped at the first hit, so one
            # frontier point can contribute several (or no) indices and the
            # semantic tensors may differ in length from the points — confirm
            # that points are never closer than 3 px to each other.
            all_coords = [tuple(point.tolist()) for point in target["unnormalized_points"]]
            semantic_indices = []
            for frontier_point in frontier:
                for idx, coord in enumerate(all_coords):
                    close_x = abs(coord[0] - frontier_point[0]) <= 2
                    if close_x and abs(coord[1] - frontier_point[1]) <= 2:
                        semantic_indices.append(idx)

            semantics = {key: [] for key in semantic_keys}
            if "semantic_left_up" in target:
                for key in semantic_keys:
                    semantics[key] = [target[key][idx].item() for idx in semantic_indices]

            # edges: every existing neighbour; last_edges: neighbours already
            # sampled; this_edges: neighbours that are frontier points too.
            edges = encode(graph, frontier, lambda adj: adj != (-1, -1))
            last_edges = encode(graph, frontier, lambda adj: adj in given)
            this_edges = encode(graph, frontier, lambda adj: adj in frontier)

            edge_ref = target["edges"]
            region = {"edges": torch.tensor(edges, dtype=edge_ref.dtype, device=edge_ref.device)}
            if "semantic_left_up" in target:
                for key in semantic_keys:
                    region[key] = torch.tensor(
                        semantics[key], dtype=target[key].dtype, device=target[key].device
                    )
            region["image_id"] = target["image_id"]
            region["size"] = target["size"]
            region["unnormalized_points"] = torch.tensor(
                frontier,
                dtype=target["unnormalized_points"].dtype,
                device=target["unnormalized_points"].device,
            )
            region["points"] = (
                torch.tensor(frontier, dtype=target["points"].dtype, device=target["points"].device)
                / target["size"]
            )
            region["last_edges"] = torch.tensor(last_edges, dtype=edge_ref.dtype, device=edge_ref.device)
            region["this_edges"] = torch.tensor(this_edges, dtype=edge_ref.dtype, device=edge_ref.device)
            region_targets.append(region)
        else:
            # Everything is sampled: emit the end sign at (505, 505).
            # NOTE(review): 16 and 11 look like the end-sign ids for the edge
            # and semantic classes — confirm against the model's class tables.
            edge_ref = target["edges"]
            raw_points = target["unnormalized_points"]
            region = {"edges": filled(edge_ref[:1].shape, 16, edge_ref)}
            if "semantic_left_up" in target:
                for key in semantic_keys:
                    region[key] = filled(target[key][:1].shape, 11, target[key])
            region["image_id"] = target["image_id"]
            region["size"] = target["size"]
            region["unnormalized_points"] = filled(raw_points[:1].shape, 505, raw_points)
            region["points"] = filled(raw_points[:1].shape, 505, target["points"]) / target["size"]
            region["last_edges"] = filled((1,), 16, edge_ref)
            region["this_edges"] = filled((1,), 16, edge_ref)
            region_targets.append(region)

    return region_targets
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def random_pertubation(sampled_points_i, sampled_edges_i):
    """Jitter every point of a sampled region by a small random integer offset.

    Each coordinate is shifted by a Gaussian draw (sigma 2) truncated to an
    integer and limited to the range -5 .. 5 with ``clip``. Edges are rewritten
    so that they refer to the shifted points.

    Args:
        sampled_points_i: dict mapping a point tuple to its sampled flag.
        sampled_edges_i: list of ``(point, point)`` pairs whose points are
            keys of ``sampled_points_i``.

    Returns:
        ``(new_points, new_edges)`` with the same flags and connectivity but
        shifted coordinates. Points that land on the same coordinates share a
        single entry, which keeps the flag of the last of them.
    """
    sigma = 2
    max_shift = 5

    def jitter():
        # One clipped integer offset.
        return clip(int(random.gauss(0, sigma)), -1 * max_shift, max_shift)

    moved = {}
    for point in sampled_points_i:
        shift_x = jitter()
        shift_y = jitter()
        moved[point] = (point[0] + shift_x, point[1] + shift_y)

    new_points = {moved[point]: flag for point, flag in sampled_points_i.items()}
    new_edges = [(moved[src], moved[dst]) for src, dst in sampled_edges_i]
    return new_points, new_edges
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def _clamped_slice(start, stop):
    """Return ``slice(start, stop)`` with negative bounds raised to 0 so they cannot wrap around."""
    return slice(max(start, 0), max(stop, 0))


def _paint_rgb(channels, rows, cols, rgb):
    """Write one 0-255 colour into the same region of the three channel planes, in place."""
    for plane, value in zip(channels, rgb):
        plane[rows, cols] = value


def draw_given_layers_on_tensors_random_region(given_layers, tensors, graphs):
    """Draw the already-given region of every sample onto its image.

    Each image is de-normalised to the 0-255 range with the imported ``mean``
    and ``std``, then painted with

    * a violet end-sign square centred on (505, 505),
    * a yellow square (radius 5, i.e. up to 11x11 px) on every sampled point,
    * a blue 3 px wide axis-aligned bar for every sampled edge,

    after the points and edges have been jittered by ``random_pertubation``.

    Args:
        given_layers: list of ``(sampled_points, sampled_edges)`` pairs, one
            per image; points are ``(x, y)`` tuples mapped to a 0/1 flag.
        tensors: indexable batch of normalised images, each indexed as
            ``[channel][row, column]`` with at least three channels.
        graphs: unused; kept so that existing callers keep working.

    Returns:
        ``(normalised_batch, unnormalised_batch)``: the painted images stacked
        along a new first dimension, re-normalised and in the 0-255 range.
    """
    rectangle_radius = 5
    line_width = 2
    half_width = int(line_width / 2)
    endsign = (505, 505)

    tensors_list = []
    unnormalized_list = []
    for i in range(len(given_layers)):
        temp_tensor = tensors[i]
        # Each plane is indexed [row, column], so rows are bounded by shape[1]
        # and columns by shape[2]. The previous code had the two swapped,
        # which only went unnoticed because the images are square.
        height = temp_tensor.shape[1]
        width = temp_tensor.shape[2]

        # Undo the normalisation; the arithmetic creates new tensors, so the
        # painting below does not touch the input batch.
        channels = [(temp_tensor[c] * std[c] + mean[c]) * 255 for c in range(3)]

        # end sign
        _paint_rgb(
            channels,
            slice(endsign[1] - rectangle_radius, endsign[1] + rectangle_radius + 1),
            slice(endsign[0] - rectangle_radius, endsign[0] + rectangle_radius + 1),
            (255, 0, 255),
        )

        sampled_points_i, sampled_edges_i = given_layers[i]
        sampled_points_i, sampled_edges_i = random_pertubation(sampled_points_i, sampled_edges_i)

        # yellow squares on the sampled points, clipped to the image
        for pos, flag in sampled_points_i.items():
            if flag != 1:
                continue
            up = int(pos[1] - rectangle_radius) if (pos[1] - rectangle_radius) >= 0 else 0
            down = int(pos[1] + rectangle_radius) if (pos[1] + rectangle_radius) < height else height - 1
            left = int(pos[0] - rectangle_radius) if (pos[0] - rectangle_radius) >= 0 else 0
            right = int(pos[0] + rectangle_radius) if (pos[0] + rectangle_radius) < width else width - 1
            _paint_rgb(channels, slice(up, down + 1), slice(left, right + 1), (255, 255, 0))

        # blue lines: each edge is drawn as a straight bar along its dominant
        # axis, placed at the midpoint of the other coordinate.
        for edge in sampled_edges_i:
            x1, y1 = int(edge[0][0]), int(edge[0][1])
            x2, y2 = int(edge[1][0]), int(edge[1][1])
            # The jitter can push coordinates below zero; clamping keeps such
            # a bar at the image border instead of letting the negative index
            # wrap around to the opposite side (or silently select nothing).
            if abs(x1 - x2) < abs(y1 - y2):
                centre = int((x1 + x2) / 2)
                rows = _clamped_slice(min(y1, y2), max(y1, y2) + 1)
                cols = _clamped_slice(centre - half_width, centre + half_width + 1)
            else:
                centre = int((y1 + y2) / 2)
                rows = _clamped_slice(centre - half_width, centre + half_width + 1)
                cols = _clamped_slice(min(x1, x2), max(x1, x2) + 1)
            _paint_rgb(channels, rows, cols, (0, 0, 255))

        unnormalized_list.append(torch.stack(tuple(channels), dim=0))

        renormalized = [((channels[c] / 255) - mean[c]) / std[c] for c in range(3)]
        tensors_list.append(torch.stack(renormalized, dim=0))

    return torch.stack(tensors_list, dim=0), torch.stack(unnormalized_list, dim=0)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def initialize_tensors(tensors):
    """Stamp the violet end sign onto every image of a normalised batch.

    Each image is de-normalised to the 0-255 range with the imported ``mean``
    and ``std``, an 11x11 violet square centred on (505, 505) is painted, and
    the image is normalised again.

    Args:
        tensors: indexable batch of normalised images with at least three
            channels, each channel indexed as ``[row, column]``.

    Returns:
        ``(normalised_batch, unnormalised_batch)``, both stacked along a new
        first dimension; the second one is in the 0-255 range.
    """
    radius = 5  # the square covers centre +/- 5 px
    # end sign (when the model predicts it, the auto-regressive iteration terminates)
    end_x, end_y = 505, 505
    rows = slice(end_y - radius, end_y + radius + 1)
    cols = slice(end_x - radius, end_x + radius + 1)
    violet = (255, 0, 255)

    normalized_batch = []
    raw_batch = []
    for idx in range(len(tensors)):
        image = tensors[idx]

        # De-normalise channel by channel; the results are new tensors.
        planes = []
        for channel in range(3):
            planes.append((image[channel] * std[channel] + mean[channel]) * 255)

        for plane, value in zip(planes, violet):
            plane[rows, cols] = value

        raw_batch.append(torch.stack(tuple(planes), dim=0))

        renormalized = []
        for channel, plane in enumerate(planes):
            renormalized.append(((plane / 255) - mean[channel]) / std[channel])
        normalized_batch.append(torch.stack(renormalized, dim=0))

    return torch.stack(normalized_batch, dim=0), torch.stack(raw_batch, dim=0)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def l1_dist(pos1, pos2):
    """Return the L1 (Manhattan) distance between two positions.

    Only components ``[0]`` and ``[1]`` of each argument are read.
    """
    delta_x = pos1[0] - pos2[0]
    delta_y = pos1[1] - pos2[1]
    return abs(delta_x) + abs(delta_y)
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def delete_graphs(targets):
    """Return deep copies of ``targets`` with the ``"graph"`` entry removed.

    The input dictionaries are not modified. A target that has no
    ``"graph"`` key raises ``KeyError``.
    """
    stripped = []
    for source in targets:
        clone = copy.deepcopy(source)
        del clone["graph"]
        stripped.append(clone)
    return stripped
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def delete_graphs_and_unnormpoints(targets):
    """Return deep copies of ``targets`` without ``"graph"`` and ``"unnormalized_points"``.

    The input dictionaries are not modified. A missing key raises
    ``KeyError`` (``"graph"`` is removed first).
    """
    stripped = []
    for source in targets:
        clone = copy.deepcopy(source)
        for unwanted in ("graph", "unnormalized_points"):
            del clone[unwanted]
        stripped.append(clone)
    return stripped
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def get_remove_point(this_preds, dist_threshold):
    """Pick the prediction to drop from the first pair of near-duplicate points.

    Two predictions form a near-duplicate pair when their positions differ
    and their Chebyshev distance is at most ``dist_threshold``.

    Args:
        this_preds: sequence of dicts; each has a ``"points"`` tensor (its
            first two components are used as the position) and a
            ``"scores"`` scalar tensor.
        dist_threshold: largest Chebyshev distance that still counts as a
            duplicate.

    Returns:
        The lower-scoring member of the first such pair (a random member
        when the scores tie), or ``None`` when no pair qualifies.
    """
    # Convert each tensor once instead of on every comparison.
    positions = [pred["points"].tolist() for pred in this_preds]
    count = len(this_preds)
    for i in range(count):
        x1, y1 = positions[i][0], positions[i][1]
        # The proximity test is symmetric, so scanning only j > i finds the
        # same first pair as scanning every ordered pair.
        for j in range(i + 1, count):
            x2, y2 = positions[j][0], positions[j][1]
            if x1 == x2 and y1 == y2:
                # Identical positions are deliberately not treated as duplicates.
                continue
            if max(abs(x1 - x2), abs(y1 - y2)) <= dist_threshold:
                first, second = this_preds[i], this_preds[j]
                first_score = first["scores"].item()
                second_score = second["scores"].item()
                if first_score < second_score:
                    return first
                if second_score < first_score:
                    return second
                return [first, second][random.randint(0, 1)]
    return None
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def point_inside(point, points_list):
    """Return True when some entry of ``points_list`` has the same position as ``point``.

    Positions are compared as tuples built from each dict's ``"points"`` tensor.
    """
    wanted = tuple(point["points"].tolist())
    return any(
        tuple(candidate["points"].tolist()) == wanted
        for candidate in points_list
    )
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def remove_points(need_to_remove_in_last_edges, this_preds):
    """Return the entries of ``this_preds`` whose position is not listed for removal.

    Matching is by position (the tuple form of each dict's ``"points"``
    tensor), so every prediction sharing a listed position is dropped.
    A new list is returned; the inputs are left unchanged.
    """
    banned_positions = [
        tuple(entry["points"].tolist()) for entry in need_to_remove_in_last_edges
    ]
    return [
        pred
        for pred in this_preds
        if tuple(pred["points"].tolist()) not in banned_positions
    ]
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
def nms(this_preds):
    """Repeatedly drop near-duplicate predictions until none remain.

    A list with at most one prediction is returned as is. Otherwise
    ``get_remove_point`` (threshold 5) selects a prediction to discard and
    ``remove_points`` rebuilds the list, until no candidate is left.
    """
    if len(this_preds) <= 1:
        return this_preds

    dist_threshold = 5
    victim = get_remove_point(this_preds, dist_threshold)
    while victim is not None:
        this_preds = remove_points([victim], this_preds)
        victim = get_remove_point(this_preds, dist_threshold)
    return this_preds
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def nms_givenpoints(this_preds, preds):
    """Drop predictions that lie close to any already-given point.

    Args:
        this_preds: list of dicts, each with a ``"points"`` tensor whose first
            two components are the position.
        preds: iterable of ``(given_points, last_edges, this_edges)`` triples;
            only ``given_points`` (dicts with a ``"points"`` tensor) is used.

    Returns:
        ``this_preds`` itself when it is empty or no given points exist.
        Otherwise a new list with every prediction removed whose Chebyshev
        distance to some given point is at most 5.
    """
    if len(this_preds) == 0:
        return this_preds

    given_positions = []
    for given_points, _last_edges, _this_edges in preds:
        for given_point in given_points:
            given_positions.append(given_point["points"].tolist())
    if len(given_positions) == 0:
        return this_preds

    dist_threshold = 5

    def _near_given_point(pred):
        pos = pred["points"].tolist()
        return any(
            max(abs(pos[0] - given[0]), abs(pos[1] - given[1])) <= dist_threshold
            for given in given_positions
        )

    # Proximity depends only on position, so one filter pass removes exactly
    # the predictions the per-hit rebuild would remove, with no deep copy of
    # the tensors.
    return [pred for pred in this_preds if not _near_given_point(pred)]
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
def random_keep(this_preds):
    """Randomly sub-sample predictions, retrying until at least one survives.

    Lists with at most one element are returned as is. The keep test is
    currently ``random.random() < 1.01``, which always succeeds, so every
    prediction is kept; one random number is still drawn per prediction.
    """
    if len(this_preds) <= 1:
        return this_preds

    keep_bound = 1.01
    while True:
        survivors = [point for point in this_preds if random.random() < keep_bound]
        if len(survivors) > 0:
            return survivors
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def is_stop(this_preds):
    """Classify whether iteration should stop for the given predictions.

    Returns:
        1 when there are no predictions (stop),
        2 when any prediction's ``"edges"`` value is 16 (normal termination),
        0 otherwise (do not stop).
    """
    if len(this_preds) == 0:
        return 1  # stop

    edge_classes = [pred["edges"].item() for pred in this_preds]
    if 16 in edge_classes:
        return 2  # normally terminate
    return 0  # not stop
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def _paint_edge(channels, edge, line_width, rgb):
    """Fill the axis-aligned strip between an edge's two endpoints, in place."""
    pos1 = tuple(int(_) for _ in edge[0]["points"].tolist())
    pos2 = tuple(int(_) for _ in edge[1]["points"].tolist())
    half_width = int(line_width / 2)
    if abs(pos1[0] - pos2[0]) < abs(pos1[1] - pos2[1]):
        # Taller than wide: span the rows between the endpoints, centred on
        # the mean column.
        centre = int((pos1[0] + pos2[0]) / 2)
        rows = slice(min(pos1[1], pos2[1]), max(pos1[1], pos2[1]) + 1)
        cols = slice(centre - half_width, centre + half_width + 1)
    else:
        # Wider than tall (or degenerate): span the columns, centred on the
        # mean row.
        centre = int((pos1[1] + pos2[1]) / 2)
        rows = slice(centre - half_width, centre + half_width + 1)
        cols = slice(min(pos1[0], pos2[0]), max(pos1[0], pos2[0]) + 1)
    for channel, value in zip(channels, rgb):
        channel[rows, cols] = value


def draw_preds_on_tensors(preds, tensors):
    """Draw the latest predictions onto each image and re-normalise the result.

    Each image is de-normalised with the module-level ``mean``/``std`` to a
    0-255 scale. Every point of ``preds[-1]`` is painted as an 11x11 square
    with channel values (255, 255, 0). Every edge in its ``last_edges`` and
    ``this_edges`` is painted as a 3-pixel-wide strip with channel values
    (0, 0, 255). Each image is then normalised again.

    Args:
        preds: sequence whose last item is ``(this_preds, last_edges, this_edges)``.
            Points are dicts with a ``"points"`` tensor; edges are pairs of
            such dicts.
        tensors: indexable batch of 3-channel images.

    Returns:
        ``(normalised_batch, unnormalised_batch)``, each stacked along dim 0.

    NOTE(review): the same ``preds[-1]`` is drawn on every image of the batch,
    and slice bounds are not clamped, so a negative start index wraps around.
    Both behaviours are inherited unchanged; confirm they are intended.
    """
    tensors_list = []
    unnormalized_list = []

    rectangle_radius = 5
    line_width = 2
    point_rgb = (255, 255, 0)
    edge_rgb = (0, 0, 255)

    for i in range(len(tensors)):
        temp_tensor = tensors[i]
        # De-normalising creates fresh per-channel tensors, so the painting
        # below does not write into the input batch.
        channels = [(temp_tensor[c] * std[c] + mean[c]) * 255 for c in range(3)]

        this_preds, last_edges, this_edges = preds[-1]
        for this_pred in this_preds:
            point = tuple(int(_) for _ in this_pred["points"].tolist())
            rows = slice(point[1] - rectangle_radius, point[1] + rectangle_radius + 1)
            cols = slice(point[0] - rectangle_radius, point[0] + rectangle_radius + 1)
            for channel, value in zip(channels, point_rgb):
                channel[rows, cols] = value

        for last_edge in last_edges:
            _paint_edge(channels, last_edge, line_width, edge_rgb)
        for this_edge in this_edges:
            _paint_edge(channels, this_edge, line_width, edge_rgb)

        unnormalized_list.append(torch.stack(tuple(channels), dim=0))

        renormalized = [((channels[c] / 255) - mean[c]) / std[c] for c in range(3)]
        tensors_list.append(torch.stack(renormalized, dim=0))

    return torch.stack(tensors_list, dim=0), torch.stack(unnormalized_list, dim=0)
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
def edge_inside(edge, edges_list):
    """Return True when ``edges_list`` contains ``edge`` in either direction.

    An edge is a pair of dicts with a ``"points"`` tensor; edges are compared
    by the tuple form of their endpoint positions.
    """
    head = tuple(edge[0]["points"].tolist())
    tail = tuple(edge[1]["points"].tolist())
    for candidate in edges_list:
        cand_head = tuple(candidate[0]["points"].tolist())
        cand_tail = tuple(candidate[1]["points"].tolist())
        if (head, tail) in ((cand_head, cand_tail), (cand_tail, cand_head)):
            return True
    return False
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
def remove_edge(edge, edges_list):
    """Return ``edges_list`` without the entries equal to ``edge``.

    Only entries with the same endpoints in the same order are dropped; an
    entry with the endpoints reversed is kept. A new list is returned.
    """
    target = (
        tuple(edge[0]["points"].tolist()),
        tuple(edge[1]["points"].tolist()),
    )
    return [
        candidate
        for candidate in edges_list
        if (
            tuple(candidate[0]["points"].tolist()),
            tuple(candidate[1]["points"].tolist()),
        )
        != target
    ]
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
def get_edges_amount(preds):
    """Return the total number of edges stored across all entries of ``preds``.

    Each entry is a ``(this_preds, last_edges, this_edges)`` triple; the
    lengths of ``last_edges`` and ``this_edges`` are summed.
    """
    return sum(
        len(last_edges) + len(this_edges)
        for _this_preds, last_edges, this_edges in preds
    )
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
def get_reserve_preds(results, keep_confidence_threshold, targets):
    """Collect the predictions with a non-zero edge class and a score at or below the threshold.

    Args:
        results: dict of tensors indexed along dim 0, with keys ``"scores"``,
            ``"points"``, ``"last_edges"``, ``"this_edges"`` and ``"edges"``.
        keep_confidence_threshold: predictions whose score is ``<=`` this
            value are selected.
        targets: sequence whose first element supplies the ``"size"`` entry
            copied into every returned prediction.

    Returns:
        A list of per-prediction dicts in ascending index order.
    """
    reserve_preds = []

    valid_label_indices_edges = torch.where(results["edges"] != 0)[0]
    valid_label_indices_scores = torch.where(results["scores"] <= keep_confidence_threshold)[0]
    # Sort the intersection: set iteration order is an implementation detail,
    # so without this the output order is not guaranteed.
    shared_indices = sorted(
        set(valid_label_indices_edges.tolist()) & set(valid_label_indices_scores.tolist())
    )
    valid_label_indices = torch.tensor(
        shared_indices,
        dtype=valid_label_indices_edges.dtype,
        device=valid_label_indices_edges.device,
    )
    copied_keys = ("scores", "points", "last_edges", "this_edges", "edges")
    for valid_label_indice in valid_label_indices:
        valid_results_i = {key: results[key][valid_label_indice] for key in copied_keys}
        valid_results_i["size"] = targets[0]["size"]
        reserve_preds.append(valid_results_i)
    return reserve_preds
|
data_preprocess/raster2graph/util/edges_utils.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Edge class id -> 4-character bit string. Ids 14 and 15 sit after "1111",
# so the numbering is not plain binary order.
# NOTE(review): each bit presumably flags one of four directions - confirm
# which bit maps to which direction against the callers.
edges = {
    0: "0000",
    1: "0001",
    2: "0010",
    3: "0011",
    4: "0100",
    5: "0110",
    6: "0111",
    7: "1000",
    8: "1001",
    9: "1011",
    10: "1100",
    11: "1101",
    12: "1110",
    13: "1111",
    14: "0101",
    15: "1010",
}


def get_edges_alldirections(edges_class):
    """Return the bit string registered for ``edges_class`` (``KeyError`` if unknown)."""
    return edges[edges_class]


# Inverse lookup: bit string -> edge class id, in the same order as ``edges``.
edges_rev = {bits: class_id for class_id, bits in edges.items()}


def get_edges_alldirections_rev(edges_class_rev):
    """Return the edge class id registered for a bit string (``KeyError`` if unknown)."""
    return edges_rev[edges_class_rev]
|
data_preprocess/raster2graph/util/geom_utils.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
from shapely.geometry import Polygon
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def poly_iou(poly1: Polygon, poly2: Polygon):
    """Return the intersection-over-union of two polygons.

    If the first attempt raises any exception, both polygons are buffered by
    1 and the computation is retried once; an error on the retry propagates.
    """

    def _area_ratio(first, second):
        intersection_area = first.intersection(second).area
        union_area = first.union(second).area
        return intersection_area / union_area

    try:
        return _area_ratio(poly1, poly2)
    except Exception:
        # Broad catch kept on purpose: any failure on the raw shapes falls
        # back to the buffered shapes.
        return _area_ratio(poly1.buffer(1), poly2.buffer(1))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def is_clockwise_or_not(points):
    """Return True when the shoelace sum over consecutive points is positive.

    Only consecutive pairs are used; the last point is not joined back to
    the first, so callers wanting a closed ring must repeat the first point.
    """
    signed_sum = 0
    for current, following in zip(points, points[1:]):
        signed_sum += current[0] * following[1] - current[1] * following[0]
    return signed_sum > 0
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def x_axis_angle(y):
    """Return the angle of vector ``y`` relative to the x axis, in degrees.

    Image coordinates are the reference: direction (1, 0) counts as 0
    degrees and the angle grows counter-clockwise up to 360 degrees.
    """
    # Flip the vertical component to work in a right-handed frame.
    flipped = (y[0], -y[1])

    x = (1, 0)
    dot = x[0] * flipped[0] + x[1] * flipped[1]
    length = (flipped[0] ** 2 + flipped[1] ** 2) ** 0.5
    # The small epsilon avoids division by zero for a zero-length vector.
    cosine = dot / (length + 1e-8)
    degrees = math.degrees(math.acos(cosine))
    if flipped[1] >= 0:
        return degrees
    return 360 - degrees
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_quadrant(angle):
    """Split the arc between two angles into its share of each 90-degree quadrant.

    Args:
        angle: pair of angles in degrees, expected in ``[0, 360]``.

    Returns:
        A 4-tuple, one entry per quadrant ``[0, 90)``, ``[90, 180)``,
        ``[180, 270)`` and ``[270, 360)``. When ``angle[0] < angle[1]`` each
        entry is the number of degrees of the arc ``angle[0] -> angle[1]``
        inside that quadrant. Otherwise the arc ``angle[1] -> angle[0]`` is
        measured and each entry is ``90`` minus that amount.

    Values outside ``[0, 360]`` are clamped to the quadrant boundaries; the
    earlier branch chain raised ``UnboundLocalError`` for them (including
    exactly 360).
    """
    ascending = angle[0] < angle[1]
    if ascending:
        low, high = angle[0], angle[1]
    else:
        low, high = angle[1], angle[0]

    # Overlap of [low, high] with each quadrant [start, start + 90],
    # floored at 0 when the arc does not reach that quadrant.
    covered = tuple(
        max(min(high, start + 90) - max(low, start), 0)
        for start in (0, 90, 180, 270)
    )
    if ascending:
        return covered
    return tuple(90 - part for part in covered)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def find_which_angle_to_counterclockwise_rotate_from(t):
    """Return ``270 - t``, or ``630 - t`` when ``t`` exceeds 270."""
    base = 630 if t > 270 else 270
    return base - t
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def counter_degree(d):
    """Return the opposite direction of ``d``: 180 degrees away, shifted down when ``d >= 180``."""
    return d - 180 if d >= 180 else d + 180
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def rotate_degree_clockwise_from_counter_degree(src_degree, dest_degree):
    """Return ``src_degree - dest_degree``, with 360 added when the difference is negative."""
    difference = src_degree - dest_degree
    if difference >= 0:
        return difference
    return 360 + difference
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def rotate_degree_counterclockwise_from_counter_degree(src_degree, dest_degree):
    """Return ``dest_degree - src_degree``, with 360 added when the difference is negative."""
    difference = dest_degree - src_degree
    if difference >= 0:
        return difference
    return 360 + difference
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def poly_area(points):
    """Return the signed area of a closed polygon.

    The ring is closed implicitly (the last vertex is joined back to the
    first). The sign of the result depends on the vertex order.
    """
    vertex_count = len(points)
    doubled_area = 0
    for index, current in enumerate(points):
        following = points[(index + 1) % vertex_count]
        doubled_area += (current[0] - following[0]) * (current[1] + following[1])
    return doubled_area / 2
|
data_preprocess/raster2graph/util/graph_utils.py
ADDED
|
@@ -0,0 +1,879 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import networkx as nx
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from util.geom_utils import (
|
| 8 |
+
get_quadrant,
|
| 9 |
+
is_clockwise_or_not,
|
| 10 |
+
poly_area,
|
| 11 |
+
rotate_degree_counterclockwise_from_counter_degree,
|
| 12 |
+
x_axis_angle,
|
| 13 |
+
)
|
| 14 |
+
from util.metric_utils import get_results, get_results_float_with_semantic
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def graph_to_tensor(graph):
    """Pack an adjacency dict into a single tensor.

    Each ``key -> neighbours`` item becomes one slice whose first row is the
    key and whose remaining rows are the neighbours. The slices are stacked
    along dim 0, so every key must have the same number of neighbours.
    """
    slices = []
    for node, neighbours in graph.items():
        rows = [list(entry) for entry in [node, *neighbours]]
        slices.append(torch.tensor(rows))
    return torch.stack(slices, dim=0)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def tensor_to_graph(tensor):
    """Unpack a tensor of ``[key, neighbour, ...]`` slices into an adjacency dict.

    For each slice, row 0 becomes the key (a tuple of Python numbers) and
    rows 1 to 4 become the list of neighbour tuples.
    """
    graph = {}
    for block in tensor:
        node = tuple(component.item() for component in block[0])
        graph[node] = [tuple(neighbour) for neighbour in block[1:5].tolist()]
    return graph
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def tensors_to_graphs_batch(tensors):
    """Convert each tensor of a batch with ``tensor_to_graph`` and return the list of graphs."""
    graphs = []
    for ts in tensors:
        graphs.append(tensor_to_graph(ts))
    return graphs
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_cycle_basis_and_semantic_deprecated(best_result):
    """Legacy cycle extraction based on ``networkx.cycle_basis``.

    The points and edges come from ``get_results_float_with_semantic``.
    Each point is indexed as ``(x, y, s0, s1, s2, s3)``; the four trailing
    entries are treated as semantic labels (presumably one per quadrant
    around the point -- TODO confirm against ``get_quadrant``).

    For every cycle of the cycle basis the polygon is brought to a common
    orientation, collinear corners are jittered until every turn has a
    non-zero cross product, and each corner votes for the labels that lie
    inside its interior angle. The most voted label wins; ties are broken
    with ``random``.

    Returns:
        ``(d_rev, simple_cycles, results)`` where ``d_rev`` maps point index
        to point, ``simple_cycles`` is the list of index cycles and
        ``results`` holds one label per cycle.

    Raises:
        AssertionError: if a corner still has a zero cross product when the
            labels are assigned.
    """

    def _jitter(point):
        # random() is drawn before randint() for each coordinate, x first.
        return (
            point[0] + 0.000001 * random.random() * [-1, 1][random.randint(0, 1)],
            point[1] + 0.000001 * random.random() * [-1, 1][random.randint(0, 1)],
            point[2],
            point[3],
            point[4],
            point[5],
        )

    def _turn_cross_products(poses):
        # poses is a closed ring (last == first); entry i of the result is
        # the z cross product of the edges meeting at vertex i.
        count = len(poses) - 1
        crosses = []
        for ind in range(count):
            a = poses[ind]
            b = poses[(ind + 1) % count]
            c = poses[(ind + 2) % count]
            # Scalar 2-D cross product; np.cross on 2-element vectors is
            # deprecated in NumPy 2.0.
            crosses.append((b[0] - a[0]) * (c[1] - b[1]) - (b[1] - a[1]) * (c[0] - b[0]))
        # The loop produced the turn at vertex ind + 1; rotate right by one.
        return crosses[-1:] + crosses[:-1]

    def _labels_in(quadrant, semantic):
        # Slot order (1, 0, 3, 2) pairs each quadrant extent with its label.
        return tuple(
            semantic[slot] if quadrant[q] >= 45 else -1
            for q, slot in enumerate((1, 0, 3, 2))
        )

    output_points, output_edges = get_results_float_with_semantic(best_result)
    d = {point: index for index, point in enumerate(output_points)}
    d_rev = dict(enumerate(output_points))

    G = nx.Graph()
    G.add_edges_from((d[edge[0]], d[edge[1]]) for edge in output_edges)

    simple_cycles = nx.cycle_basis(G)
    results = []

    for cycle in simple_cycles:
        points = [d_rev[ind] for ind in cycle]
        points.append(points[0])

        if is_clockwise_or_not([(p[0], p[1]) for p in points]):
            points.reverse()

        poses = [(p[0], p[1]) for p in points]
        cross_products = _turn_cross_products(poses)

        # Nudge every collinear corner until no turn is degenerate.
        while 0 in cross_products:
            for point_ind, cross_product in enumerate(cross_products):
                if cross_product == 0:
                    points[point_ind] = _jitter(points[point_ind])
                    if point_ind == 0:
                        # Keep the ring closed.
                        points[-1] = points[0]
            poses = [(p[0], p[1]) for p in points]
            cross_products = _turn_cross_products(poses)

        semantics = [[p[2], p[3], p[4], p[5]] for p in points]

        count = len(poses) - 1
        degrees = []
        for ind in range(count):
            a = poses[ind]
            b = poses[(ind + 1) % count]
            c = poses[(ind + 2) % count]
            ei_minus = [-(b[0] - a[0]), -(b[1] - a[1])]
            eiplus1 = [c[0] - b[0], c[1] - b[1]]
            degrees.append((x_axis_angle(ei_minus), x_axis_angle(eiplus1)))
        degrees = degrees[-1:] + degrees[:-1]

        angles_to_semantics = []
        for angle_ind, degree in enumerate(degrees):
            low, high = min(degree), max(degree)
            quadrant1 = get_quadrant((low, high))
            quadrant2 = get_quadrant((high, low))
            semantic1 = _labels_in(quadrant1, semantics[angle_ind])
            semantic2 = _labels_in(quadrant2, semantics[angle_ind])
            first_is_smaller = sum(quadrant1) < sum(quadrant2)

            xproduct = cross_products[angle_ind]
            if xproduct < 0:
                angles_to_semantics.append(semantic1 if first_is_smaller else semantic2)
            elif xproduct > 0:
                angles_to_semantics.append(semantic2 if first_is_smaller else semantic1)
            else:
                # Explicit raise: a bare ``assert`` vanishes under ``python -O``.
                raise AssertionError("degenerate corner left after jittering")

        semantic_result = dict.fromkeys(range(0, 13), 0)
        for everypoint_semantic in angles_to_semantics:
            valid_labels = [s for s in everypoint_semantic if s != -1]
            for label in valid_labels:
                semantic_result[label] += 1 / len(valid_labels)

        ranked = sorted(semantic_result.items(), key=lambda item: item[1], reverse=True)
        if ranked[0][1] > ranked[1][1]:
            this_cycle_result = ranked[0][0]
        else:
            tied = [label for label, score in ranked if score == ranked[0][1]]
            this_cycle_result = tied[random.randint(0, len(tied) - 1)]
        results.append(this_cycle_result)

    return d_rev, simple_cycles, results
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _index_points(output_points):
    """Return the ``point -> index`` and ``index -> point`` lookup dicts."""
    return (
        {point: index for index, point in enumerate(output_points)},
        dict(enumerate(output_points)),
    )


def _remove_bridges(output_edges, d):
    """Build the index graph and drop every bridge from graph and edge list.

    Returns ``(G, kept_edges)``. Edges are filtered in one pass, so an edge
    listed in both orientations no longer triggers a second
    ``G.remove_edge`` on an already removed edge.
    """
    G = nx.Graph()
    G.add_edges_from((d[edge[0]], d[edge[1]]) for edge in output_edges)
    # Materialise first: nx.bridges is lazy and G is modified right after.
    bridges = list(nx.bridges(G))
    G.remove_edges_from(bridges)
    bridge_keys = {frozenset(bridge) for bridge in bridges}
    kept_edges = [
        edge
        for edge in output_edges
        if frozenset((d[edge[0]], d[edge[1]])) not in bridge_keys
    ]
    return G, kept_edges


def _walk_cycle(start_edge, component_edges, d, vertex_semantic=None):
    """Trace one closed walk starting along *start_edge*.

    The walk starts at the lower-indexed endpoint of *start_edge* and, at
    every vertex, leaves through the edge with the largest value of
    ``rotate_degree_counterclockwise_from_counter_degree`` measured from
    the edge it arrived on. It stops once it is about to re-enter the
    second vertex of the walk.

    Returns ``(points, numbers, semantics)``. ``points`` and ``numbers``
    end with the repeated first two vertices; ``semantics`` holds one
    ``vertex_semantic(...)`` result per visited vertex, or is empty when
    no callback is given.

    Raises:
        RuntimeError: if the walk does not close within twice the number
            of component edges (the original loop would never return).
        ValueError: from ``max`` if a vertex has no outgoing candidate.
    """
    if d[start_edge[0]] < d[start_edge[1]]:
        initial_point, current_point = start_edge[0], start_edge[1]
    else:
        initial_point, current_point = start_edge[1], start_edge[0]

    points = [initial_point, current_point]
    numbers = [d[initial_point], d[current_point]]
    semantics = []

    stop_point = current_point
    last_point = initial_point
    next_point = None
    # A well-formed walk uses each directed edge at most once.
    max_steps = 2 * len(component_edges)
    steps = 0

    while next_point != stop_point:
        steps += 1
        if steps > max_steps:
            raise RuntimeError("cycle walk did not close; edge set is degenerate")

        back_degree = None
        candidates = []
        for edge in component_edges:
            if edge[0] == current_point:
                other = edge[1]
            elif edge[1] == current_point:
                other = edge[0]
            else:
                continue
            degree = x_axis_angle((other[0] - current_point[0], other[1] - current_point[1]))
            if edge == (current_point, last_point) or edge == (last_point, current_point):
                # The edge we arrived on is the reference direction, never a candidate.
                back_degree = degree
            else:
                candidates.append((other, degree))

        deltas = [
            rotate_degree_counterclockwise_from_counter_degree(back_degree, degree)
            for _, degree in candidates
        ]
        next_point, chosen_degree = candidates[deltas.index(max(deltas))]

        if vertex_semantic is not None:
            semantics.append(vertex_semantic(current_point, chosen_degree, back_degree))

        last_point, current_point = current_point, next_point
        points.append(current_point)
        numbers.append(d[current_point])

    return points, numbers, semantics


def _consume_walked_edges(points, numbers, pending):
    """Remove from *pending* every edge the walk crossed in ascending index order."""
    for ind in range(len(numbers) - 1):
        if numbers[ind] < numbers[ind + 1]:
            forward = (points[ind], points[ind + 1])
            backward = (points[ind + 1], points[ind])
            if forward in pending:
                pending.remove(forward)
            elif backward in pending:
                pending.remove(backward)


def _quadrant_filtered_labels(point, chosen_degree, back_degree):
    """Labels of *point* kept only where the interior angle spans >= 45 degrees."""
    labels = [point[3], point[2], point[5], point[4]]
    quadrant = get_quadrant(
        (min(chosen_degree, back_degree), max(chosen_degree, back_degree))
    )
    if back_degree < chosen_degree:
        # The angle runs the other way round: take the complement per quadrant.
        quadrant = tuple(90 - quadrant[q] for q in range(4))
    return [label if quadrant[q] >= 45 else -1 for q, label in enumerate(labels)]


def _all_labels(point, chosen_degree, back_degree):
    """Labels of *point* without any angular filtering."""
    return [point[3], point[2], point[5], point[4]]


def _fractional_votes(cycle_semantics):
    """Each vertex spreads one vote evenly over its valid (non ``-1``) labels."""
    tally = dict.fromkeys(range(0, 13), 0)
    for vertex_labels in cycle_semantics:
        valid = [label for label in vertex_labels if label != -1]
        for label in valid:
            tally[label] += 1 / len(valid)
    return tally


def _presence_votes(cycle_semantics):
    """Each vertex gives one full vote to every label 0..12 it carries."""
    tally = dict.fromkeys(range(0, 13), 0)
    for vertex_labels in cycle_semantics:
        for label in range(0, 13):
            if label in vertex_labels:
                tally[label] += 1
    return tally


def _pick_label(tally):
    """Return the top-voted label, ignoring 11 and 12; ties are broken randomly."""
    del tally[11]
    del tally[12]
    ranked = sorted(tally.items(), key=lambda item: item[1], reverse=True)
    if ranked[0][1] > ranked[1][1]:
        return ranked[0][0]
    tied = [label for label, score in ranked if score == ranked[0][1]]
    return tied[random.randint(0, len(tied) - 1)]


def _collect_cycles(output_points, output_edges, vertex_semantic=None, tally_votes=None):
    """Shared driver: bridge removal, per-component walks, area filter, voting.

    Returns ``(d_rev, cycles, cycle_numbers, cycle_labels)``. ``cycle_labels``
    stays empty when *tally_votes* is ``None``. The inputs are not mutated.
    """
    d, d_rev = _index_points(output_points)
    G, edges = _remove_bridges(output_edges, d)

    cycles = []
    cycle_numbers = []
    cycle_labels = []

    for component in list(nx.connected_components(G)):
        if len(component) == 1:
            continue
        component_edges = [
            edge for edge in edges if d[edge[0]] in component or d[edge[1]] in component
        ]
        pending = list(component_edges)

        for start_edge in component_edges:
            if start_edge not in pending:
                continue
            points, numbers, semantics = _walk_cycle(
                start_edge, component_edges, d, vertex_semantic
            )
            _consume_walked_edges(points, numbers, pending)

            # Drop the repeated second vertex; the ring stays closed on the first.
            points.pop(-1)
            numbers.pop(-1)

            # Area is taken on integer coordinates with y negated, without
            # the closing vertex; only walks with positive area are kept.
            polygon = [(int(p[0]), -int(p[1])) for p in points[:-1]]
            if poly_area(polygon) > 0:
                cycles.append(points)
                cycle_numbers.append(numbers)
                if tally_votes is not None:
                    cycle_labels.append(_pick_label(tally_votes(semantics)))

    return d_rev, cycles, cycle_numbers, cycle_labels


def get_cycle_basis_and_semantic(best_result):
    """Extract the closed cycles of the predicted graph with one label each.

    Points and edges come from ``get_results_float_with_semantic``. Bridges
    are removed, every remaining component is walked edge by edge and only
    walks with positive ``poly_area`` are kept. Each vertex votes for the
    labels whose quadrant is covered by at least 45 degrees of its interior
    angle; labels 11 and 12 never win.

    Returns:
        ``(d_rev, simple_cycles, simple_cycles_semantics)``: the index to
        point mapping, the cycles as closed point lists (first point
        repeated at the end) and one label per cycle.
    """
    output_points, output_edges = get_results_float_with_semantic(best_result)
    d_rev, cycles, _, labels = _collect_cycles(
        output_points, output_edges, _quadrant_filtered_labels, _fractional_votes
    )
    return d_rev, cycles, labels


def get_cycle_basis_and_semantic_2(best_result):
    """Variant of ``get_cycle_basis_and_semantic`` without angular filtering.

    Every vertex contributes all four of its labels, and a label receives
    one vote per vertex that carries it. Labels 11 and 12 never win.

    Returns:
        ``(d_rev, simple_cycles, simple_cycles_semantics)``.
    """
    output_points, output_edges = get_results_float_with_semantic(best_result)
    d_rev, cycles, _, labels = _collect_cycles(
        output_points, output_edges, _all_labels, _presence_votes
    )
    return d_rev, cycles, labels


def get_cycle_basis(best_result):
    """Extract the closed cycles of the graph returned by ``get_results``.

    Returns:
        ``(d_rev, simple_cycles, simple_cycles_number)``: the index to point
        mapping, the cycles as closed point lists and the same cycles
        expressed as point indices.
    """
    output_points, output_edges = get_results(best_result)
    d_rev, cycles, numbers, _ = _collect_cycles(output_points, output_edges)
    return d_rev, cycles, numbers
|
data_preprocess/raster2graph/util/image_id_dict.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data_preprocess/raster2graph/util/math_utils.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def clip(number, _min, _max):
    """Clamp ``number`` to the closed interval ``[_min, _max]``.

    The lower bound is checked first, so ``_min`` is returned whenever
    ``number <= _min`` (even if the bounds are inverted); otherwise ``_max``
    is returned when ``number >= _max``; otherwise ``number`` itself.
    """
    # Guard clauses in the same order as the comparisons were originally made.
    if number <= _min:
        return _min
    if number >= _max:
        return _max
    return number
|
data_preprocess/raster2graph/util/mean_std.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Three-component mean; presumably per-channel image normalisation statistics
# on a 0-1 scale -- TODO confirm channel order with the code that consumes it.
mean = [0.920, 0.913, 0.891]
|
| 2 |
+
# Three-component standard deviation; presumably per-channel image
# normalisation statistics on a 0-1 scale -- TODO confirm with the consumer.
std = [0.214, 0.216, 0.228]
|
data_preprocess/raster2graph/util/metric_utils.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
from shapely.geometry import Polygon
|
| 5 |
+
from util.geom_utils import poly_iou
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def calculate_AP(valid_results, ground_truths, confidence_final):
    """Compute the average precision of predicted points against ground truth.

    Predictions from all images are pooled, ranked by descending score and
    filtered to those with ``score > confidence_final``. For every prefix of
    that ranking a (precision, recall) pair is recorded, and the area under
    the resulting precision/recall staircase is returned.

    Args:
        valid_results: mapping ``image_id -> prediction``. Each key must expose
            ``.item()``; each prediction is a mapping with ``"points"`` (rows
            indexable as ``[x, y]``), ``"scores"`` (elements exposing
            ``.item()``) and ``"size"`` (exposes ``.tolist()``).
        ground_truths: mapping ``image_id -> {"points": rows}`` where each row
            is ``[x, y, flag]`` and ``rows`` exposes ``.tolist()``. Only rows
            whose flag equals 0 are available for matching. The mapping is
            never modified.
        confidence_final: score threshold (exclusive).

    Returns:
        The average precision. ``0`` when no prediction passes the threshold.

    Raises:
        KeyError: if a prediction refers to an image id that is missing from
            ``ground_truths``.

    A prediction is a true positive when its nearest (Euclidean) still
    unmatched ground-truth point lies strictly within 1% of ``size[1]`` along
    x and 1% of ``size[0]`` along y; that ground-truth point is then consumed.
    """
    all_preds = []
    for image_id, image_pred in valid_results.items():
        for i in range(len(image_pred["points"])):
            all_preds.append(
                {
                    "score": image_pred["scores"][i].item(),
                    "point": tuple(image_pred["points"][i].tolist()),
                    "size": tuple(image_pred["size"].tolist()),
                    "image_id": image_id.item(),
                }
            )
    # Stable sort: equally scored predictions keep their insertion order.
    all_preds.sort(key=lambda pred: pred["score"], reverse=True)
    all_preds = [pred for pred in all_preds if pred["score"] > confidence_final]
    if not all_preds:
        # No prefix to evaluate: the curve is the single anchor (0, 1), area 0.
        return 0

    # Private, mutable copy of the ground truth as plain lists so that the
    # caller's data is left untouched while flags are flipped during matching.
    gt_state = {img_id: entry["points"].tolist() for img_id, entry in ground_truths.items()}
    unmatched = sum(1 for rows in gt_state.values() for row in rows if row[2] == 0)

    # Matching is greedy in rank order, so the state after the first n
    # predictions is the same whether or not later ones are processed. One
    # incremental pass therefore yields the metrics of every prefix.
    tp = 0
    fp = 0
    all_metrics = []
    for pred in all_preds:
        pred_x, pred_y = pred["point"][0], pred["point"][1]
        tol_x = pred["size"][1] * 0.01
        tol_y = pred["size"][0] * 0.01

        nearest = None  # (euclidean distance, dx, dy, ground-truth row)
        for row in gt_state[pred["image_id"]]:
            if row[2] != 0:
                continue
            dx = abs(pred_x - row[0])
            dy = abs(pred_y - row[1])
            euc = math.sqrt(dx**2 + dy**2)
            # Strict "<" keeps the first of several equally near candidates.
            if nearest is None or euc < nearest[0]:
                nearest = (euc, dx, dy, row)

        if nearest is not None and nearest[1] < tol_x and nearest[2] < tol_y:
            nearest[3][2] = 1  # consume the matched ground-truth point
            tp += 1
            unmatched -= 1
        else:
            fp += 1

        precision = tp / (tp + fp)
        # With no matchable ground truth at all, recall is undefined; report
        # 0.0 instead of dividing by zero.
        recall = tp / (tp + unmatched) if (tp + unmatched) else 0.0
        all_metrics.append((precision, recall))

    # For equal recall keep the highest precision (last write wins after sort).
    all_metrics.sort(key=lambda metric: (metric[1], metric[0]))
    p_r_curve_points = {recall: precision for precision, recall in all_metrics}
    p_r_curve_points[0] = 1  # anchor the curve at recall 0, precision 1
    curve = sorted(p_r_curve_points.items(), key=lambda rp: rp[0])

    AP = 0
    for (prev_recall, _), (recall, precision) in zip(curve, curve[1:]):
        AP += (recall - prev_recall) * precision
    return AP
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_results(best_result):
    """Flatten the per-layer predictions of ``best_result`` into point and edge lists.

    ``best_result[2]`` is iterated as a sequence of triplets
    ``(this_preds, last_edges, this_edges)``. Every node is a mapping whose
    ``"points"`` value supports ``.int().tolist()``; an edge is a pair of nodes.

    Returns:
        ``(output_points, output_edges)`` where points are integer tuples and
        edges are ``(point1, point2)`` tuples. Within each triplet the
        ``last_edges`` come before the ``this_edges``.
    """

    def as_int_point(node):
        return tuple(node["points"].int().tolist())

    output_points = []
    output_edges = []
    for triplet in best_result[2]:
        output_points.extend(as_int_point(node) for node in triplet[0])
        # triplet[1] holds the "last" edges, triplet[2] the "this" edges.
        for edge_group in (triplet[1], triplet[2]):
            output_edges.extend((as_int_point(edge[0]), as_int_point(edge[1])) for edge in edge_group)
    return output_points, output_edges
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_results_visual(best_result):
    """Flatten ``best_result[2]`` like ``get_results`` but tag each item with its layer.

    Returns:
        ``(output_points, output_edges, layer_count)``. Each point entry is
        ``[layer_index, point]`` and each edge entry is
        ``[layer_index, (point1, point2)]``, where ``layer_index`` is the
        position of the originating triplet in ``best_result[2]`` and
        ``layer_count`` is ``len(best_result[2])``.
    """

    def as_int_point(node):
        return tuple(node["points"].int().tolist())

    layers = best_result[2]
    output_points = []
    output_edges = []
    for layer_index, triplet in enumerate(layers):
        output_points.extend([layer_index, as_int_point(node)] for node in triplet[0])
        # "last" edges (triplet[1]) are emitted before "this" edges (triplet[2]).
        for edge_group in (triplet[1], triplet[2]):
            output_edges.extend(
                [layer_index, (as_int_point(edge[0]), as_int_point(edge[1]))] for edge in edge_group
            )
    return output_points, output_edges, len(layers)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def get_results_float_with_semantic(best_result):
    """Flatten ``best_result[2]`` keeping raw coordinates and the four semantic labels.

    Every node is converted to the 6-tuple
    ``(x, y, semantic_left_up, semantic_right_up, semantic_right_down, semantic_left_down)``
    where ``x`` and ``y`` are the first two entries of ``node["points"].tolist()``
    (no integer cast) and each label is read with ``.item()``.

    Returns:
        ``(output_points, output_edges)``; edges are pairs of such 6-tuples,
        with each triplet's ``last_edges`` listed before its ``this_edges``.
    """

    def describe(node):
        xy = node["points"].tolist()
        return (
            xy[0],
            xy[1],
            node["semantic_left_up"].item(),
            node["semantic_right_up"].item(),
            node["semantic_right_down"].item(),
            node["semantic_left_down"].item(),
        )

    output_points = []
    output_edges = []
    for triplet in best_result[2]:
        for node in triplet[0]:
            output_points.append(describe(node))
        for edge_group in (triplet[1], triplet[2]):
            for edge in edge_group:
                output_edges.append((describe(edge[0]), describe(edge[1])))

    return output_points, output_edges
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def calculate_single_sample(
    best_result, graph, target_d_rev, target_simple_cycles, target_results, d_rev, simple_cycles, results
):
    """Count point, edge, region and room matches for one sample.

    Args:
        best_result: prediction structure passed to ``get_results``.
        graph: mapping ``point -> adjacent points``; adjacent entries equal to
            ``(-1, -1)`` are ignored.
        target_d_rev: not used by this function.
        target_simple_cycles: ground-truth cycles (sequences of points).
        target_results: per ground-truth-region labels, indexed like
            ``target_simple_cycles``.
        d_rev: not used by this function.
        simple_cycles: predicted cycles; the last point of each is dropped
            before comparison.
        results: per predicted-region labels, indexed like ``simple_cycles``.

    Returns:
        A 13-tuple ``(points_TP, points_FP, points_FN, edges_TP, edges_FP,
        edges_FN, dist_error, regions_TP, regions_FP, regions_FN, rooms_TP,
        rooms_FP, rooms_FN)`` where ``dist_error`` is the tuple of summed
        ``(|dx|, |dy|, euclidean)`` errors over the matched points.

    Points and edge end points match when both coordinate differences are at
    most 5; regions match when ``poly_iou`` is at least 0.7. Each ground-truth
    item can be matched at most once.
    """
    output_points, output_edges = get_results(best_result)

    # Ground-truth points are the graph keys; each undirected edge is listed once.
    gt_points = list(graph.keys())
    gt_edges = []
    for node, neighbours in graph.items():
        for adj in neighbours:
            if adj != (-1, -1) and (adj, node) not in gt_edges:
                gt_edges.append((node, adj))

    threshold = 5

    def close(p, q):
        return (abs(p[0] - q[0]) <= threshold) and (abs(p[1] - q[1]) <= threshold)

    # ---- points -----------------------------------------------------------
    points_TP = 0
    points_FP = 0
    dist_error_x = 0
    dist_error_y = 0
    dist_error_l2 = 0
    free_points = copy.deepcopy(gt_points)
    for output_point in output_points:
        for gt_point in gt_points:
            if close(output_point, gt_point) and gt_point in free_points:
                dx = abs(output_point[0] - gt_point[0])
                dy = abs(output_point[1] - gt_point[1])
                points_TP += 1
                dist_error_x += dx
                dist_error_y += dy
                dist_error_l2 += (dx**2 + dy**2) ** 0.5
                free_points.remove(gt_point)
                break
        else:
            # No still-free ground-truth point was close enough.
            points_FP += 1
    points_FN = len(gt_points) - points_TP

    # ---- edges (direction-insensitive) --------------------------------------
    edges_TP = 0
    edges_FP = 0
    free_edges = copy.deepcopy(gt_edges)
    for output_edge in output_edges:
        for gt_edge in gt_edges:
            same_direction = close(output_edge[0], gt_edge[0]) and close(output_edge[1], gt_edge[1])
            if same_direction or (close(output_edge[0], gt_edge[1]) and close(output_edge[1], gt_edge[0])):
                if gt_edge in free_edges:
                    edges_TP += 1
                    free_edges.remove(gt_edge)
                    break
        else:
            edges_FP += 1
    edges_FN = len(gt_edges) - edges_TP

    # ---- regions and rooms --------------------------------------------------
    gt_regions = [[(pt[0], pt[1]) for pt in cycle] for cycle in target_simple_cycles]
    output_regions = []
    for cycle in simple_cycles:
        polyg = [(pt[0], pt[1]) for pt in cycle]
        del polyg[-1]  # the final point of a predicted cycle is discarded
        output_regions.append(polyg)

    regions_TP = 0
    regions_FP = 0
    rooms_TP = 0
    rooms_FP = 0
    free_regions = copy.deepcopy(gt_regions)
    iou_threshold = 0.7
    for output_region_i, output_region in enumerate(output_regions):
        for gt_region_i, gt_region in enumerate(gt_regions):
            if poly_iou(Polygon(gt_region), Polygon(output_region)) >= iou_threshold:
                if gt_region in free_regions:
                    regions_TP += 1
                    # A geometrically matched region is a room hit only if the labels agree.
                    if target_results[gt_region_i] == results[output_region_i]:
                        rooms_TP += 1
                    else:
                        rooms_FP += 1
                    free_regions.remove(gt_region)
                    break
        else:
            regions_FP += 1
            rooms_FP += 1
    regions_FN = len(gt_regions) - regions_TP
    rooms_FN = len(gt_regions) - rooms_TP

    dist_error = (dist_error_x, dist_error_y, dist_error_l2) if points_TP > 0 else (0, 0, 0)
    return (
        points_TP,
        points_FP,
        points_FN,
        edges_TP,
        edges_FP,
        edges_FN,
        dist_error,
        regions_TP,
        regions_FP,
        regions_FN,
        rooms_TP,
        rooms_FP,
        rooms_FN,
    )
|
data_preprocess/raster2graph/util/semantics_dict.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Semantic class name -> integer id (0 is "no_type", 1-12 are the named classes).
semantics_dict = {
|
| 2 |
+
"living_room": 1,
|
| 3 |
+
"kitchen": 2,
|
| 4 |
+
"bedroom": 3,
|
| 5 |
+
"bathroom": 4,
|
| 6 |
+
"restroom": 5,
|
| 7 |
+
"balcony": 6,
|
| 8 |
+
"closet": 7,
|
| 9 |
+
"corridor": 8,
|
| 10 |
+
"washing_room": 9,
|
| 11 |
+
"PS": 10,
|
| 12 |
+
"outside": 11,
|
| 13 |
+
"wall": 12,
|
| 14 |
+
"no_type": 0,
|
| 15 |
+
}
|
| 16 |
+
# Inverse of semantics_dict: integer id -> semantic class name.
semantics_dict_rev = {
|
| 17 |
+
0: "no_type",
|
| 18 |
+
1: "living_room",
|
| 19 |
+
2: "kitchen",
|
| 20 |
+
3: "bedroom",
|
| 21 |
+
4: "bathroom",
|
| 22 |
+
5: "restroom",
|
| 23 |
+
6: "balcony",
|
| 24 |
+
7: "closet",
|
| 25 |
+
8: "corridor",
|
| 26 |
+
9: "washing_room",
|
| 27 |
+
10: "PS",
|
| 28 |
+
11: "outside",
|
| 29 |
+
12: "wall",
|
| 30 |
+
}
|
| 31 |
+
# Semantic class name -> 3-tuple of integers; presumably a drawing colour for
# visualisation -- channel order (RGB vs. BGR) is not shown here, TODO confirm.
semantics_dict_color = {
|
| 32 |
+
"living_room": (0, 0, 220),
|
| 33 |
+
"kitchen": (0, 220, 220),
|
| 34 |
+
"bedroom": (0, 220, 0),
|
| 35 |
+
"bathroom": (220, 220, 0),
|
| 36 |
+
"restroom": (220, 0, 0),
|
| 37 |
+
"balcony": (220, 0, 220),
|
| 38 |
+
"closet": (110, 0, 110),
|
| 39 |
+
"corridor": (110, 0, 0),
|
| 40 |
+
"washing_room": (0, 0, 110),
|
| 41 |
+
"PS": (0, 110, 110),
|
| 42 |
+
"outside": (0, 0, 0),
|
| 43 |
+
"wall": (110, 110, 110),
|
| 44 |
+
"no_type": (20, 20, 20),
|
| 45 |
+
}
|
data_preprocess/stru3d/PointCloudReaderPanorama.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import open3d as o3d
|
| 6 |
+
|
| 7 |
+
# NOTE(review): not referenced by the PointCloudReaderPanorama class below;
# purpose unclear -- confirm no other module imports it before removing.
NUM_SECTIONS = -1
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class PointCloudReaderPanorama:
    """Build a point cloud from the panorama renderings found under one scene folder.

    On construction the reader lists ``<path>/2D_rendering``, derives the
    depth, RGB, normal and camera-position file paths of every entry, reads the
    camera centres and immediately generates the point cloud, which is stored
    in ``self.point_cloud`` as a dict with ``"coords"`` and ``"colors"``.

    Per-point normals are never computed, so ``self.point_cloud`` has no
    ``"normals"`` entry; methods that would emit normals skip them when the
    entry is absent.
    """

    def __init__(self, path, resolution="full", random_level=0, generate_color=False, generate_normal=False):
        """Collect the input file paths and generate the point cloud.

        Args:
            path: scene folder containing a ``2D_rendering`` sub-folder.
            resolution: name of the sub-folder of ``panorama`` holding the images.
            random_level: scale of the uniform noise added to every depth value.
            generate_color: whether colours are attached in ``visualize`` / ``export_ply``.
            generate_normal: whether normals are attached in ``visualize`` /
                ``export_ply`` (only effective if the cloud has ``"normals"``).
        """
        self.path = path
        self.random_level = random_level
        self.resolution = resolution
        self.generate_color = generate_color
        self.generate_normal = generate_normal

        rendering_root = os.path.join(path, "2D_rendering")
        sections = os.listdir(rendering_root)

        def panorama_file(section, *parts):
            return os.path.join(rendering_root, section, "panorama", *parts)

        self.depth_paths = [panorama_file(p, self.resolution, "depth.png") for p in sections]
        self.rgb_paths = [panorama_file(p, self.resolution, "rgb_coldlight.png") for p in sections]
        self.normal_paths = [panorama_file(p, self.resolution, "normal.png") for p in sections]
        self.camera_paths = [panorama_file(p, "camera_xyz.txt") for p in sections]
        self.camera_centers = self.read_camera_center()
        self.point_cloud = self.generate_point_cloud(
            self.random_level, color=self.generate_color, normal=self.generate_normal
        )

    def read_camera_center(self):
        """Return one 3-element array per camera file, parsed from its first line.

        The first line is split on single spaces and the first three values are
        used as the camera centre.
        """
        camera_centers = []
        for camera_path in self.camera_paths:
            with open(camera_path, "r") as f:
                line = f.readline()
            center = list(map(float, line.strip().split(" ")))
            camera_centers.append(np.asarray([center[0], center[1], center[2]]))
        return camera_centers

    def generate_point_cloud(self, random_level=0, color=False, normal=False):
        """Back-project every depth panorama into 3D and merge the results.

        Each pixel row maps to an elevation angle in (-90, 90] degrees and each
        column to an azimuth in [-180, 180) degrees. Pixels whose (noisy) depth
        is not above 500 are discarded. Coordinates are snapped to a grid of 10
        in x/y and 100 in z and de-duplicated.

        Args:
            random_level: scale of the uniform ``[0, 1)`` noise added to each depth.
            color: accepted for interface compatibility; colours are always stored.
            normal: accepted for interface compatibility; normals are never stored.

        Returns:
            dict with ``"coords"`` (N x 3) and ``"colors"`` (N x 3, scaled to 0-1).

        Raises:
            FileNotFoundError: if a depth or RGB image cannot be read.
        """
        coords = []
        colors = []
        points = {}

        for depth_path, rgb_path, camera_center in zip(self.depth_paths, self.rgb_paths, self.camera_centers):
            depth_img = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)
            if depth_img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError("Could not read depth image: %s" % depth_path)
            rgb_img = cv2.imread(rgb_path)
            if rgb_img is None:
                raise FileNotFoundError("Could not read RGB image: %s" % rgb_path)
            rgb_img = cv2.cvtColor(rgb_img, code=cv2.COLOR_BGR2RGB)

            rows, cols = depth_img.shape[0], depth_img.shape[1]
            x_tick = 180.0 / rows
            y_tick = 360.0 / cols

            for x in range(rows):
                # Elevation depends on the row only, so its trigonometry is hoisted.
                alpha = 90 - (x * x_tick)
                sin_alpha = np.sin(np.deg2rad(alpha))
                cos_alpha = np.cos(np.deg2rad(alpha))
                for y in range(cols):
                    beta = y * y_tick - 180

                    depth = depth_img[x, y] + np.random.random() * random_level

                    if depth > 500.0:
                        z_offset = depth * sin_alpha
                        xy_offset = depth * cos_alpha
                        x_offset = xy_offset * np.sin(np.deg2rad(beta))
                        y_offset = xy_offset * np.cos(np.deg2rad(beta))
                        point = np.asarray([x_offset, y_offset, z_offset])
                        coords.append(point + camera_center)
                        colors.append(rgb_img[x, y])

        coords = np.asarray(coords)
        colors = np.asarray(colors) / 255.0

        # Snap to a coarse grid so that near-identical samples collapse to one point.
        coords[:, :2] = np.round(coords[:, :2] / 10) * 10.0
        coords[:, 2] = np.round(coords[:, 2] / 100) * 100.0
        _, unique_ind = np.unique(coords, return_index=True, axis=0)

        points["coords"] = coords[unique_ind]
        points["colors"] = colors[unique_ind]

        print("Pointcloud size:", points["coords"].shape[0])
        return points

    def get_point_cloud(self):
        """Return the point-cloud dict generated at construction time."""
        return self.point_cloud

    def generate_density(self, width=256, height=256):
        """Project the cloud onto the x/y plane as a density image plus a normal map.

        Args:
            width: output image width in pixels.
            height: output image height in pixels.

        Returns:
            ``(density, normals_map, normalization_dict)`` where ``density`` is a
            ``(height, width)`` float32 array of per-pixel point counts divided
            by the maximum count, ``normals_map`` is a ``(height, width, 3)``
            uint8 array of averaged estimated normals (clipped to 0-1, scaled by
            255) and ``normalization_dict`` holds the padded ``"min_coords"`` /
            ``"max_coords"`` and the ``"image_res"`` used for the projection.
        """
        import time

        # Work on a copy with the z axis negated; x and y keep their sign.
        ps = self.point_cloud["coords"].copy()
        ps[:, 2] *= -1

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(ps)
        pcd.estimate_normals()

        image_res = np.array((width, height))

        max_coords = np.max(ps, axis=0)
        min_coords = np.min(ps, axis=0)
        max_m_min = max_coords - min_coords

        # Pad the bounding box by 10% of its extent on every side.
        max_coords = max_coords + 0.1 * max_m_min
        min_coords = min_coords - 0.1 * max_m_min

        normalization_dict = {}
        normalization_dict["min_coords"] = min_coords
        normalization_dict["max_coords"] = max_coords
        normalization_dict["image_res"] = image_res

        coordinates = np.round(
            (ps[:, :2] - min_coords[None, :2]) / (max_coords[None, :2] - min_coords[None, :2]) * image_res[None]
        )
        coordinates = np.minimum(np.maximum(coordinates, np.zeros_like(image_res)), image_res - 1)

        density = np.zeros((height, width), dtype=np.float32)

        unique_coordinates, counts = np.unique(coordinates, return_counts=True, axis=0)
        unique_coordinates = unique_coordinates.astype(np.int32)

        density[unique_coordinates[:, 1], unique_coordinates[:, 0]] = counts
        density = density / np.max(density)

        normals = np.array(pcd.normals)
        normals_map = np.zeros((density.shape[0], density.shape[1], 3))

        # Only every 10th point contributes to the normal map; pixels hit by
        # none of the sampled points receive the mean of an empty selection.
        sampled_coordinates = coordinates[::10]
        sampled_normals = normals[::10]

        start_time = time.time()
        for unique_coord in unique_coordinates:
            normals_indcs = np.argwhere(np.all(sampled_coordinates == unique_coord, axis=1))[:, 0]
            normals_map[unique_coord[1], unique_coord[0], :] = np.mean(sampled_normals[normals_indcs, :], axis=0)

        print("Time for normals: ", time.time() - start_time)

        normals_map = (np.clip(normals_map, 0, 1) * 255).astype(np.uint8)

        return density, normals_map, normalization_dict

    def visualize(self, export_path=None):
        """Show the points with z below 2000 in an Open3D window and optionally save them.

        The displayed copy has its y axis flipped and all coordinates divided
        by 1000; the stored point cloud is not modified.

        Args:
            export_path: if not ``None``, the displayed cloud is written there
                with ``o3d.io.write_point_cloud``.
        """
        pcd = o3d.geometry.PointCloud()

        points = self.point_cloud["coords"]

        print(np.max(points, axis=0))
        indices = np.where(points[:, 2] < 2000)

        # Indexing with the np.where result yields a copy, so the in-place
        # edits below never touch self.point_cloud.
        points = points[indices]
        points[:, 1] *= -1
        points[:, :] /= 1000
        pcd.points = o3d.utility.Vector3dVector(points)

        # The point cloud carries no "normals" entry unless a caller added one.
        if self.generate_normal and "normals" in self.point_cloud:
            normals = self.point_cloud["normals"]
            normals = normals[indices]
            pcd.normals = o3d.utility.Vector3dVector(normals)
        if self.generate_color:
            colors = self.point_cloud["colors"]
            colors = colors[indices]
            pcd.colors = o3d.utility.Vector3dVector(colors)

        pcd.estimate_normals()

        o3d.visualization.draw_geometries([pcd])

        if export_path is not None:
            o3d.io.write_point_cloud(export_path, pcd)

    def export_ply(self, path):
        """Write the point cloud to ``path`` as an ASCII PLY file.

        The header declares ``x y z`` for every vertex, followed by
        ``red green blue`` when ``generate_color`` is set and by ``nx ny nz``
        when ``generate_normal`` is set and the cloud has a ``"normals"``
        entry. Each vertex line lists the values in that same order, separated
        by single spaces.
        """
        coords = self.point_cloud["coords"]
        # Declare and write normals only when they exist, so the header always
        # agrees with the vertex rows.
        with_normals = self.generate_normal and "normals" in self.point_cloud

        with open(path, "w") as f:
            f.write("ply\n")
            f.write("format ascii 1.0\n")
            f.write("element vertex %d\n" % coords.shape[0])
            f.write("property float x\n")
            f.write("property float y\n")
            f.write("property float z\n")
            if self.generate_color:
                f.write("property uchar red\n")
                f.write("property uchar green\n")
                f.write("property uchar blue\n")
            if with_normals:
                f.write("property float nx\n")
                f.write("property float ny\n")
                f.write("property float nz\n")
            f.write("end_header\n")
            for i in range(coords.shape[0]):
                normal = []
                color = []
                coord = coords[i].tolist()
                if self.generate_color:
                    color = list(map(int, (self.point_cloud["colors"][i] * 255).tolist()))
                if with_normals:
                    normal = self.point_cloud["normals"][i].tolist()
                data = coord + color + normal
                f.write(" ".join(list(map(str, data))) + "\n")
|
data_preprocess/stru3d/generate_coco_stru3d.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
from stru3d_utils import generate_coco_dict, generate_density, normalize_annotations, parse_floor_plan_polys
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
sys.path.append("../.")
|
| 10 |
+
from common_utils import export_density, read_scene_pc
|
| 11 |
+
|
| 12 |
+
### Note: Some scenes have missing/wrong annotations. These are the indices that you should additionally exclude
|
| 13 |
+
### to be consistent with MonteFloor and HEAT:
|
| 14 |
+
invalid_scenes_ids = [
|
| 15 |
+
76,
|
| 16 |
+
183,
|
| 17 |
+
335,
|
| 18 |
+
491,
|
| 19 |
+
663,
|
| 20 |
+
681,
|
| 21 |
+
703,
|
| 22 |
+
728,
|
| 23 |
+
865,
|
| 24 |
+
936,
|
| 25 |
+
985,
|
| 26 |
+
986,
|
| 27 |
+
1009,
|
| 28 |
+
1104,
|
| 29 |
+
1155,
|
| 30 |
+
1221,
|
| 31 |
+
1282,
|
| 32 |
+
1365,
|
| 33 |
+
1378,
|
| 34 |
+
1635,
|
| 35 |
+
1745,
|
| 36 |
+
1772,
|
| 37 |
+
1774,
|
| 38 |
+
1816,
|
| 39 |
+
1866,
|
| 40 |
+
2037,
|
| 41 |
+
2076,
|
| 42 |
+
2274,
|
| 43 |
+
2334,
|
| 44 |
+
2357,
|
| 45 |
+
2580,
|
| 46 |
+
2665,
|
| 47 |
+
2706,
|
| 48 |
+
2713,
|
| 49 |
+
2771,
|
| 50 |
+
2868,
|
| 51 |
+
3156,
|
| 52 |
+
3192,
|
| 53 |
+
3198,
|
| 54 |
+
3261,
|
| 55 |
+
3271,
|
| 56 |
+
3276,
|
| 57 |
+
3296,
|
| 58 |
+
3342,
|
| 59 |
+
3387,
|
| 60 |
+
3398,
|
| 61 |
+
3466,
|
| 62 |
+
3496,
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
# Semantic type name -> integer category id; main() turns every entry into a
# COCO "categories" record with supercategory "room".
type2id = {
|
| 66 |
+
"living room": 0,
|
| 67 |
+
"kitchen": 1,
|
| 68 |
+
"bedroom": 2,
|
| 69 |
+
"bathroom": 3,
|
| 70 |
+
"balcony": 4,
|
| 71 |
+
"corridor": 5,
|
| 72 |
+
"dining room": 6,
|
| 73 |
+
"study": 7,
|
| 74 |
+
"studio": 8,
|
| 75 |
+
"store room": 9,
|
| 76 |
+
"garden": 10,
|
| 77 |
+
"laundry room": 11,
|
| 78 |
+
"office": 12,
|
| 79 |
+
"basement": 13,
|
| 80 |
+
"garage": 14,
|
| 81 |
+
"undefined": 15,
|
| 82 |
+
"door": 16,
|
| 83 |
+
"window": 17,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def config(argv=None):
    """Parse the command-line options of the Structured3D COCO export script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``data_root`` and ``output``.
    """
    a = argparse.ArgumentParser(description="Generate coco format data for Structured3D")
    a.add_argument(
        "--data_root", default="Structured3D_panorama", type=str, help="path to raw Structured3D_panorama folder"
    )
    a.add_argument("--output", default="coco_stru3d", type=str, help="path to output folder")

    return a.parse_args(argv)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def main(args):
    """Convert pre-generated Structured3D point clouds into a COCO-style dataset.

    For every scene found under ``<data_root>/<part>/Structured3D/`` the
    point cloud is projected to a 256x256 density image, the 3D annotations
    are rescaled to that image, and the resulting polygons are appended to
    the train / val / test COCO dictionaries. The split is decided by the
    numeric scene id: below 3000 is train, 3000-3249 is val, everything else
    is test. Scenes listed in ``invalid_scenes_ids`` are skipped.

    Args:
        args: Namespace providing ``data_root`` (input folder) and ``output``
            (output folder).
    """
    data_root = args.data_root
    # os.listdir returns entries in arbitrary order; sorting makes the
    # running instance ids reproducible from one run to the next.
    data_parts = sorted(os.listdir(data_root))

    ### prepare
    outFolder = args.output
    annotation_outFolder = os.path.join(outFolder, "annotations")
    splits = ("train", "val", "test")
    img_folders = {split: os.path.join(outFolder, split) for split in splits}

    # makedirs also creates outFolder itself (and any missing parents).
    for folder in [annotation_outFolder, *img_folders.values()]:
        os.makedirs(folder, exist_ok=True)

    categories = [{"supercategory": "room", "id": value, "name": key} for key, value in type2id.items()]
    coco_dicts = {split: {"images": [], "annotations": [], "categories": list(categories)} for split in splits}

    skipped_scene_ids = set(invalid_scenes_ids)

    ### begin processing
    instance_id = 0
    for part in tqdm(data_parts):
        part_root = os.path.join(data_root, part, "Structured3D")
        for scene in tqdm(sorted(os.listdir(part_root))):
            scene_path = os.path.join(part_root, scene)
            scene_id = scene.split("_")[-1]
            img_id = int(scene_id)

            if img_id in skipped_scene_ids:
                print("skip {}".format(scene))
                continue

            # load pre-generated point cloud
            ply_path = os.path.join(scene_path, "point_cloud.ply")
            points = read_scene_pc(ply_path)
            xyz = points[:, :3]

            ### project point cloud to density map
            density, normalization_dict = generate_density(xyz, width=256, height=256)

            ### rescale raw annotations
            normalized_annos = normalize_annotations(scene_path, normalization_dict)

            ### prepare coco dict
            img_dict = {
                "file_name": scene_id + ".png",
                "id": img_id,
                "width": 256,
                "height": 256,
            }

            ### parse annotations
            polys = parse_floor_plan_polys(normalized_annos)
            polygons_list = generate_coco_dict(normalized_annos, polys, instance_id, img_id, ignore_types=["outwall"])

            # Annotation ids are global across all splits.
            instance_id += len(polygons_list)

            if img_id < 3000:
                split = "train"
            elif img_id < 3250:
                split = "val"
            else:
                split = "test"

            coco_dicts[split]["images"].append(img_dict)
            coco_dicts[split]["annotations"] += polygons_list
            export_density(density, img_folders[split], scene_id)

            print(scene_id)

    for split in splits:
        with open(os.path.join(annotation_outFolder, split + ".json"), "w") as f:
            json.dump(coco_dicts[split], f)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
|
| 199 |
+
main(config())
|
data_preprocess/stru3d/generate_point_cloud_stru3d.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from PointCloudReaderPanorama import PointCloudReaderPanorama
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def config(argv=None):
    """Parse the command-line options of the point-cloud generation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with the attribute ``data_root``.
    """
    a = argparse.ArgumentParser(description="Generate point cloud for Structured3D")
    a.add_argument(
        "--data_root", default="Structured3D_panorama", type=str, help="path to raw Structured3D_panorama folder"
    )
    return a.parse_args(argv)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main(args):
    """Write a ``point_cloud.ply`` next to every scene under ``args.data_root``.

    Scenes are looked up as ``<data_root>/<part>/Structured3D/<scene>``; each
    one is handed to ``PointCloudReaderPanorama`` and exported in place.

    Args:
        args: Namespace providing ``data_root``.
    """
    print("Creating point cloud from perspective views...")
    root = args.data_root
    parts = os.listdir(root)

    for part in tqdm(parts):
        part_dir = os.path.join(root, part, "Structured3D")
        for scene in tqdm(os.listdir(part_dir)):
            scene_dir = os.path.join(part_dir, scene)
            reader = PointCloudReaderPanorama(scene_dir, random_level=0, generate_color=True, generate_normal=False)
            # The point cloud is stored inside the scene folder itself.
            reader.export_ply(os.path.join(scene_dir, "point_cloud.ply"))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
main(config())
|
data_preprocess/stru3d/stru3d_utils.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This code is adapted from the Structured3D code base.
|
| 3 |
+
|
| 4 |
+
Reference: https://github.com/bertjiazheng/Structured3D
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from shapely.geometry import Polygon
|
| 13 |
+
|
| 14 |
+
sys.path.append("../data_preprocess")
|
| 15 |
+
from common_utils import resort_corners
|
| 16 |
+
|
| 17 |
+
type2id = {
|
| 18 |
+
"living room": 0,
|
| 19 |
+
"kitchen": 1,
|
| 20 |
+
"bedroom": 2,
|
| 21 |
+
"bathroom": 3,
|
| 22 |
+
"balcony": 4,
|
| 23 |
+
"corridor": 5,
|
| 24 |
+
"dining room": 6,
|
| 25 |
+
"study": 7,
|
| 26 |
+
"studio": 8,
|
| 27 |
+
"store room": 9,
|
| 28 |
+
"garden": 10,
|
| 29 |
+
"laundry room": 11,
|
| 30 |
+
"office": 12,
|
| 31 |
+
"basement": 13,
|
| 32 |
+
"garage": 14,
|
| 33 |
+
"undefined": 15,
|
| 34 |
+
"door": 16,
|
| 35 |
+
"window": 17,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def generate_density(point_cloud, width=256, height=256):
    """Project a point cloud onto the x/y plane as a normalised density image.

    Every column from index 2 onwards is negated, the bounding box of the
    points is enlarged by 10 % of its extent on each side, and the x/y
    coordinates are scaled to pixel positions. Each pixel holds the number
    of points that landed on it, divided by the largest such count.

    Args:
        point_cloud: Array-like of shape (N, D) with D >= 2 and N >= 1.
            The input is copied, never modified.
        width: Width of the output image in pixels.
        height: Height of the output image in pixels.

    Returns:
        Tuple ``(density, normalization_dict)``. ``density`` is a float32
        array of shape ``(height, width)`` with maximum value 1.
        ``normalization_dict`` holds ``min_coords``, ``max_coords`` (the
        padded bounds, one value per input column) and ``image_res``
        (``[width, height]``).

    Raises:
        ValueError: If ``point_cloud`` is empty or is not two-dimensional
            with at least two columns.
    """
    ps = np.array(point_cloud)  # np.array copies, so the caller's data is left untouched
    if ps.ndim != 2 or ps.shape[0] == 0 or ps.shape[1] < 2:
        raise ValueError("point_cloud must be a non-empty array of shape (N, D) with D >= 2")

    # Equivalent to negating everything and then un-negating x and y.
    ps[:, 2:] *= -1

    image_res = np.array((width, height))

    max_coords = np.max(ps, axis=0)
    min_coords = np.min(ps, axis=0)
    extent = max_coords - min_coords

    margin = 0.1 * extent
    # A zero extent along x or y would make the scaling below divide 0 by 0
    # and produce NaN pixel coordinates. Give such an axis a fixed margin so
    # the padded range is non-empty and the points land mid-image.
    margin[np.flatnonzero(extent[:2] == 0)] = 0.5

    max_coords = max_coords + margin
    min_coords = min_coords - margin

    normalization_dict = {
        "min_coords": min_coords,
        "max_coords": max_coords,
        "image_res": image_res,
    }

    coordinates = np.round(
        (ps[:, :2] - min_coords[None, :2]) / (max_coords[None, :2] - min_coords[None, :2]) * image_res[None]
    )
    # Rounding can reach image_res itself; clamp to the last valid pixel.
    coordinates = np.minimum(np.maximum(coordinates, np.zeros_like(image_res)), image_res - 1)

    density = np.zeros((height, width), dtype=np.float32)

    unique_coordinates, counts = np.unique(coordinates, return_counts=True, axis=0)
    unique_coordinates = unique_coordinates.astype(np.int32)

    # Column 0 is x (image column), column 1 is y (image row).
    density[unique_coordinates[:, 1], unique_coordinates[:, 0]] = counts
    density = density / np.max(density)

    return density, normalization_dict
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def normalize_point(point, normalization_dict):
    """Rescale the first two entries of ``point`` to pixel coordinates, in place.

    Args:
        point: Mutable sequence whose first two entries are rescaled; any
            further entries are left as they are.
        normalization_dict: Mapping with ``min_coords``, ``max_coords`` and
            ``image_res`` arrays, as returned by ``generate_density``.

    Returns:
        The same ``point`` object, with its first two entries replaced by
        the rounded pixel position clamped to ``[0, image_res - 1]``.
    """
    lower = normalization_dict["min_coords"][:2]
    upper = normalization_dict["max_coords"][:2]
    resolution = normalization_dict["image_res"]

    pixel = np.round((point[:2] - lower) / (upper - lower) * resolution)
    pixel = np.minimum(np.maximum(pixel, np.zeros_like(resolution)), resolution - 1)

    point[:2] = pixel.tolist()
    return point
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def normalize_annotations(scene_path, normalization_dict):
    """Load ``annotation_3d.json`` of a scene and rescale its points to pixels.

    Args:
        scene_path: Folder that contains ``annotation_3d.json``.
        normalization_dict: Bounds and resolution passed on to
            ``normalize_point``.

    Returns:
        The parsed annotation dictionary, in which the ``point`` of every
        entry in ``lines`` and the ``coordinate`` of every entry in
        ``junctions`` have been run through ``normalize_point``.
    """
    with open(os.path.join(scene_path, "annotation_3d.json"), "r") as f:
        annos = json.load(f)

    # Lines are processed before junctions.
    for collection, field in (("lines", "point"), ("junctions", "coordinate")):
        for entry in annos[collection]:
            entry[field] = normalize_point(entry[field], normalization_dict)

    return annos
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def parse_floor_plan_polys(annos):
    """Extract one junction loop per floor plane, tagged with its semantic type.

    Args:
        annos: Annotation dictionary providing ``semantics``, ``planes``,
            ``planeLineMatrix`` and ``lineJunctionMatrix``.

    Returns:
        List of ``[junction_indices, semantic_type]`` pairs, one for every
        plane of type ``"floor"`` referenced by a semantic entry. Only the
        first loop returned by ``convert_lines_to_vertices`` is kept for
        each plane.
    """
    # (planeID, semantic type) of every floor plane, in annotation order.
    floor_planes = [
        (plane_id, semantic["type"])
        for semantic in annos["semantics"]
        for plane_id in semantic["planeID"]
        if annos["planes"][plane_id]["type"] == "floor"
    ]

    # construct each polygon
    polygons = []
    for plane_id, plane_type in floor_planes:
        # Non-zero entries of the plane's row mark the lines bounding it ...
        line_ids = np.where(np.array(annos["planeLineMatrix"][plane_id]))[0].tolist()
        # ... and non-zero entries of a line's row mark its junctions.
        junction_pairs = [
            np.where(np.array(annos["lineJunctionMatrix"][line_id]))[0].tolist() for line_id in line_ids
        ]
        loops = convert_lines_to_vertices(junction_pairs)
        polygons.append([loops[0], plane_type])

    return polygons
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def convert_lines_to_vertices(lines):
    """
    Chain junction pairs into closed loops of junction indices.

    Args:
        lines: Sequence of two-element junction-index pairs, one per line.

    Returns:
        List of loops; each loop is a list of junction indices in traversal
        order, starting with the two junctions of the line that seeded it.

    An open chain, i.e. one whose last junction has no remaining line
    attached to it, raises IndexError.
    """
    remaining = np.array(lines)
    loops = []
    current = None

    while len(remaining) != 0:
        if current is None:
            # Seed a new loop with the first line that is still unused.
            current = remaining[0].tolist()
            remaining = np.delete(remaining, 0, 0)

        # Find the lines touching the loop's last junction and step to the
        # opposite end of the first of them.
        rows, cols = np.where(remaining == current[-1])
        next_vertex = remaining[rows[0], 1 - cols[0]]
        remaining = np.delete(remaining, rows, 0)

        if next_vertex in current:
            # Back at a junction already visited: the loop is complete.
            loops.append(current)
            current = None
            continue

        current.append(next_vertex)

    return loops
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def generate_coco_dict(annos, polygons, curr_instance_id, curr_img_id, ignore_types):
    """Turn junction-index polygons into COCO annotation dictionaries.

    Args:
        annos: Annotation dictionary; the first two entries of each
            ``junctions[i]["coordinate"]`` are used as the 2D position.
        polygons: Iterable of ``(junction_indices, type_name)`` pairs.
        curr_instance_id: Id given to the first annotation produced; each
            further annotation gets the next integer.
        curr_img_id: Value stored as ``image_id`` of every annotation.
        ignore_types: Type names to drop.

    Returns:
        List of COCO annotation dictionaries. Polygons of type ``door`` or
        ``window`` are dropped when their area is below 1 and otherwise
        reduced to a two-point segment; all other polygons are dropped when
        their area is below 100.
    """
    junctions = np.array([junc["coordinate"][:2] for junc in annos["junctions"]])
    opening_types = ("door", "window")

    coco_annotation_dict_list = []

    for polygon, poly_type in polygons:
        if poly_type in ignore_types:
            continue

        polygon = junctions[np.array(polygon)]

        poly_shapely = Polygon(polygon)
        area = poly_shapely.area

        is_opening = poly_type in opening_types
        # Openings only need an area of 1, everything else needs 100.
        if area < (1 if is_opening else 100):
            continue

        # The bounding box is taken from the full polygon, before an opening
        # is reduced to a segment below.
        rectangle_shapely = poly_shapely.envelope

        ### here we convert door/window annotation into a single line
        if is_opening:
            assert polygon.shape[0] == 4
            midp_1 = (polygon[0] + polygon[1]) / 2
            midp_2 = (polygon[1] + polygon[2]) / 2
            midp_3 = (polygon[2] + polygon[3]) / 2
            midp_4 = (polygon[3] + polygon[0]) / 2

            # Keep the pair of opposite edge midpoints that lie further apart.
            dist_1_3 = np.square(midp_1 - midp_3).sum()
            dist_2_4 = np.square(midp_2 - midp_4).sum()
            if dist_1_3 > dist_2_4:
                polygon = np.vstack([midp_1, midp_3])
            else:
                polygon = np.vstack([midp_2, midp_4])

        coco_seg_poly = []
        poly_sorted = resort_corners(polygon)

        for p in poly_sorted:
            coco_seg_poly += list(p)

        # Slightly wider bounding box, clamped to a 256x256 image
        bound_pad = 2
        bb_x, bb_y = rectangle_shapely.exterior.xy
        bb_x = np.unique(bb_x)
        bb_y = np.unique(bb_y)
        bb_x_min = np.maximum(np.min(bb_x) - bound_pad, 0)
        bb_y_min = np.maximum(np.min(bb_y) - bound_pad, 0)

        bb_x_max = np.minimum(np.max(bb_x) + bound_pad, 256 - 1)
        bb_y_max = np.minimum(np.max(bb_y) + bound_pad, 256 - 1)

        bb_width = bb_x_max - bb_x_min
        bb_height = bb_y_max - bb_y_min

        coco_bb = [bb_x_min, bb_y_min, bb_width, bb_height]

        coco_annotation_dict = {
            "segmentation": [coco_seg_poly],
            "area": area,
            "iscrowd": 0,
            "image_id": curr_img_id,
            "bbox": coco_bb,
            "category_id": type2id[poly_type],
            "id": curr_instance_id,
        }

        coco_annotation_dict_list.append(coco_annotation_dict)
        curr_instance_id += 1

    return coco_annotation_dict_list
|
data_preprocess/tools/plot_data.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
# Additional useful arguments:
|
| 4 |
+
# --crop_white_space: remove redundant whitespace from the rendering
|
| 5 |
+
# --one_color: use single color for every room (i.e. yellow)
|
| 6 |
+
# --compute_stats: compute statistics of the dataset (e.g. max_num_pts, max_num_polys)
|
| 7 |
+
# and plot histogram for counting number of Points, Rooms, Corners
|
| 8 |
+
# --drop_wd: disable Window & Door in the plots
|
| 9 |
+
# --image_scale: adjust rendering resolution of the plots
|
| 10 |
+
|
| 11 |
+
SPLIT=test
|
| 12 |
+
python plot_floor.py --dataset_name=stru3d \
|
| 13 |
+
--dataset_root=data/coco_s3d_bw/ \
|
| 14 |
+
--eval_set=${SPLIT} \
|
| 15 |
+
--output_dir=data_plots/output_gt_s3dbw/${SPLIT} \
|
| 16 |
+
--semantic_classes=19 \
|
| 17 |
+
--input_channels 3 \
|
| 18 |
+
--disable_image_transform \
|
| 19 |
+
--poly2seq \
|
| 20 |
+
--image_size 256 \
|
| 21 |
+
--image_scale 1 \
|
| 22 |
+
--compute_stats \
|
| 23 |
+
--plot_gt \
|
| 24 |
+
--plot_gt_image \
|
| 25 |
+
--plot_polys \
|
| 26 |
+
--plot_density
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
SPLIT=test
|
| 30 |
+
python plot_floor.py --dataset_name=r2g \
|
| 31 |
+
--dataset_root=data/R2G_hr_dataset_processed_v1/ \
|
| 32 |
+
--eval_set=${SPLIT} \
|
| 33 |
+
--output_dir=output_gt_r2g/${SPLIT} \
|
| 34 |
+
--semantic_classes=13 \
|
| 35 |
+
--input_channels 3 \
|
| 36 |
+
--poly2seq \
|
| 37 |
+
--disable_image_transform \
|
| 38 |
+
--image_size 256 \
|
| 39 |
+
--image_scale 1 \
|
| 40 |
+
--compute_stats \
|
| 41 |
+
--plot_gt \
|
| 42 |
+
--plot_polys \
|
| 43 |
+
--plot_density
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
SPLIT=test
|
| 47 |
+
python plot_floor.py --dataset_name=cubicasa \
|
| 48 |
+
--dataset_root=data/coco_cubicasa5k_nowalls_v4-1_refined \
|
| 49 |
+
--eval_set=${SPLIT} \
|
| 50 |
+
--output_dir=data_plots/output_gt_cc5k/${SPLIT} \
|
| 51 |
+
--semantic_classes=12 \
|
| 52 |
+
--input_channels 3 \
|
| 53 |
+
--disable_image_transform \
|
| 54 |
+
--poly2seq \
|
| 55 |
+
--image_size 256 \
|
| 56 |
+
--image_scale 1 \
|
| 57 |
+
--compute_stats \
|
| 58 |
+
--plot_gt \
|
| 59 |
+
--plot_polys \
|
| 60 |
+
--plot_density
|
data_preprocess/tools/run_cc5k.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# create COCO-style dataset for CubiCasa5k
python -m data_preprocess.cubicasa5k.create_coco_cc5k --data_root=data/cubicasa5k/ \
    --output=data/coco_cubicasa5k_nowalls_v4/ \
    --disable_wd2line

# Split examples that contain more than one floorplan into separate samples
python -m data_preprocess.cubicasa5k.floorplan_extraction \
    --data_root data/coco_cubicasa5k_nowalls_v4/ \
    --output data/coco_cubicasa5k_nowalls_v4-1_refined/

# Merge individual JSONs into single JSON file per split (train/val/test)
# This must be done after floorplan_extraction.py
python -m data_preprocess.cubicasa5k.combine_json \
    --input data/coco_cubicasa5k_nowalls_v4-1_refined/ \
    --output data/coco_cubicasa5k_nowalls_v4-1_refined/annotations/
|
data_preprocess/tools/run_r2g.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# preprocess raw Raster2Graph high-resolution dataset
python -m data_preprocess.raster2graph.image_process --data_root=data/R2G_hr_dataset/

# convert to COCO-style dataset
python -m data_preprocess.raster2graph.convert_to_coco --dataset_path data/R2G_hr_dataset/ --output_dir data/R2G_hr_dataset_processed/

# combine JSON files into single JSON file per split
python -m data_preprocess.raster2graph.combine_json \
    --input data/R2G_hr_dataset_processed/ \
    --output data/R2G_hr_dataset_processed_v1/

# remove the intermediate per-file output once it has been combined
rm -rf data/R2G_hr_dataset_processed/
|
data_preprocess/tools/run_s3d.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Assumes the Structured3D density dataset has already been downloaded
|
| 2 |
+
DATA=data/coco_s3d
|
| 3 |
+
|
| 4 |
+
for split in train val test; do
|
| 5 |
+
python plot_floor.py --dataset_name=stru3d \
|
| 6 |
+
--dataset_root=${DATA} \
|
| 7 |
+
--eval_set=${split} \
|
| 8 |
+
--output_dir=data/coco_s3d_bw/${split}/ \
|
| 9 |
+
--semantic_classes=19 \
|
| 10 |
+
--input_channels 3 \
|
| 11 |
+
--disable_image_transform \
|
| 12 |
+
--poly2seq \
|
| 13 |
+
--image_size 256 \
|
| 14 |
+
--image_scale 1 \
|
| 15 |
+
--plot_gt \
|
| 16 |
+
--is_bw \
|
| 17 |
+
--plot_engine matplotlib
|
| 18 |
+
|
| 19 |
+
done
|
| 20 |
+
|
| 21 |
+
# Reuse the annotations
|
| 22 |
+
cp -r data/coco_s3d/annotations data/coco_s3d_bw/
|
data_preprocess/tools/run_waffle.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python -m data_preprocess.waffle.create_coco_waffle_benchmark \
|
| 2 |
+
--data_root data/waffle/benchmark/ \
|
| 3 |
+
--output data/waffle_benchmark_processed/
|
data_preprocess/waffle/create_coco_waffle_benchmark.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from glob import glob
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from shapely.geometry import Polygon
|
| 12 |
+
|
| 13 |
+
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
| 14 |
+
from common_utils import resort_corners
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def draw_polygon_on_image(image, polygons, class_to_color):
    """
    Fill every polygon on the image with the colour assigned to its class.

    Args:
        image (numpy.ndarray): The image handed to ``cv2.fillPoly``.
        polygons: Iterable of ``(polygon, polygon_class)`` pairs; each
            polygon is converted to an int32 array and reshaped to (-1, 2).
        class_to_color: Mapping from ``polygon_class`` to a colour whose
            entries 0, 1 and 2 are treated as R, G and B.

    Returns:
        numpy.ndarray: The ``image`` object that was passed in.
    """
    for outline, label in polygons:
        vertices = np.array(outline, dtype=np.int32).reshape(-1, 2)
        rgb = class_to_color[label]
        # OpenCV wants BGR, so the channel order is reversed.
        cv2.fillPoly(image, [vertices], (rgb[2], rgb[1], rgb[0]))

    return image
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def fill_mask(segmentation_mask):
    """
    Build a uint8 mask in which the outer contour of every non-zero class is filled with 255.

    Args:
        segmentation_mask (numpy.ndarray): Array of class indices; index 0 is
            treated as background and skipped.

    Returns:
        numpy.ndarray: uint8 array with the same shape as the input.
    """
    filled = np.zeros_like(segmentation_mask, dtype=np.uint8)
    foreground_labels = [label for label in np.unique(segmentation_mask) if label != 0]

    for label in foreground_labels:
        # Binary mask of the pixels belonging to this class
        region = (segmentation_mask == label).astype(np.uint8)
        # Only external contours are requested, so holes get filled too
        outlines, _ = cv2.findContours(region, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(filled, outlines, -1, 255, thickness=cv2.FILLED)

    return filled
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def to_bw_image(input_image):
    """
    Convert a BGR image to grayscale and binarise it with ``cv2.threshold`` (threshold 127, maximum value 255).

    Args:
        input_image (numpy.ndarray): Image passed to ``cv2.cvtColor`` with ``COLOR_BGR2GRAY``.

    Returns:
        numpy.ndarray: The thresholded image.
    """
    grayscale = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY)
    # cv2.threshold returns (threshold used, image); only the image is needed.
    return cv2.threshold(grayscale, 127, 255, cv2.THRESH_BINARY)[1]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2):
    """
    Build a padded ``[x, y, width, height]`` box around the given coordinates.

    Args:
        bb_x: x coordinates to enclose.
        bb_y: y coordinates to enclose.
        image_width: The right edge is clamped to ``image_width - 1``.
        image_height: The bottom edge is clamped to ``image_height - 1``.
        bound_pad: Padding added on every side before clamping.

    Returns:
        list: ``[x_min, y_min, width, height]`` as NumPy scalars; the left
        and top edges are clamped to 0.
    """
    xs = np.asarray(bb_x)
    ys = np.asarray(bb_y)

    left = np.maximum(xs.min() - bound_pad, 0)
    top = np.maximum(ys.min() - bound_pad, 0)
    right = np.minimum(xs.max() + bound_pad, image_width - 1)
    bottom = np.minimum(ys.max() + bound_pad, image_height - 1)

    return [left, top, right - left, bottom - top]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def prepare_dict(categories_dict):
    """
    Create an empty COCO dictionary whose categories come from ``categories_dict``.

    Args:
        categories_dict: Mapping from category name to category id.

    Returns:
        dict: ``images`` and ``annotations`` are empty lists; ``categories``
        has one entry per item, each with supercategory ``"room"``.
    """
    categories = [
        {"supercategory": "room", "id": category_id, "name": name}
        for name, category_id in categories_dict.items()
    ]
    return {"images": [], "annotations": [], "categories": categories}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def convert_numpy_to_python(obj):
    """
    Convert a NumPy scalar or array to its plain Python counterpart.

    Args:
        obj: Any object.

    Returns:
        ``int`` for NumPy integers, ``float`` for NumPy floats, a nested
        list for arrays, and ``obj`` itself for everything else.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Anything else is passed through untouched.
    return obj
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def config(argv=None):
    """Parse the command-line options of the WAFFLE benchmark export script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``data_root`` and ``output``.
    """
    a = argparse.ArgumentParser(description="Generate coco format data for WAFFLE BENCHMARK SET")
    a.add_argument("--data_root", default="data/waffle/benchmark/", type=str, help="path to WAFFLE BENCHMARK folder")
    a.add_argument("--output", default="data/waffle_benchmark_processed/", type=str, help="path to output folder")

    return a.parse_args(argv)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
if __name__ == "__main__":
|
| 114 |
+
LABEL_NOTATIONS = {
|
| 115 |
+
"Background": (0, 0, 0), # Black
|
| 116 |
+
"Interior": (255, 255, 255), # White
|
| 117 |
+
"Walls": (255, 0, 0), # Red
|
| 118 |
+
"Doors": (0, 0, 255), # Blue
|
| 119 |
+
"Windows": (0, 255, 255), # Cyan
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# Convert colour-coded segmentation PNGs into a COCO-style "test" split:
# one cleaned image per input, polygon annotations for interior/door/window
# regions, plus auxiliary debug images and an old-name -> new-name mapping.

# Class name -> raw class index; passed to prepare_dict() below.
CLASS2INDEX = {
    "Background": 0,  # Black
    "Interior": 1,  # White
    # "Walls": 2,  # Red
    "Doors": 3,  # Blue
    "Windows": 4,  # Cyan
}

# Create a mapping from RGB values to class indices
COLOR_TO_CLASS = {
    (0, 0, 0): 0,  # Background
    (255, 255, 255): 1,  # Interior
    (255, 0, 0): 2,  # Walls
    (0, 0, 255): 3,  # Doors
    (0, 255, 255): 4,  # Windows
}

# Raw class index -> category id written to the annotations. Raw classes that
# are missing here (background, walls) produce no polygon annotations.
NEW_CLASS_MAPPING = {
    1: 0,
    3: 1,
    4: 2,
}

# Category id -> colour, handed to draw_polygon_on_image() for the debug overlay.
CLASS_TO_COLOR = {
    0: (255, 255, 255),  # Interior
    1: (0, 0, 255),  # Doors
    2: (0, 255, 255),  # Windows
}

args = config()

root = args.data_root
image_dir = f"{root}/pngs"
label_dir = f"{root}/segmented_descrete_pngs"
input_paths = sorted(glob(f"{label_dir}/*.png"))

output_dir = args.output
output_aux_dir = f"{output_dir}/aux"
output_image_dir = f"{output_dir}/test/"
output_annot_dir = f"{output_dir}/annotations/"
fn_mapping_log = f"{output_annot_dir}/test_image_id_mapping.json"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_aux_dir, exist_ok=True)
os.makedirs(output_image_dir, exist_ok=True)
os.makedirs(output_annot_dir, exist_ok=True)

# Running annotation id, unique across all images.
instance_count = 0

save_dict = prepare_dict(CLASS2INDEX)
output_mappings = []

for i, path in enumerate(input_paths):
    mask = Image.open(path).convert("RGB")
    fn = os.path.basename(path).replace("_seg_colors.png", "")
    # Output files are renamed to their zero-padded position in the sorted input list.
    new_fn = str(i).zfill(5)

    mask = np.array(mask)
    image = Image.open(os.path.join(image_dir, f"{fn}.png")).convert("RGB")
    image_width, image_height = image.size

    # Initialize an empty segmentation mask with the same height and width as the input mask
    segmentation_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8)

    img_id = i
    img_dict = {}
    img_dict["file_name"] = str(img_id).zfill(5) + ".png"
    img_dict["id"] = img_id
    img_dict["width"] = image_width
    img_dict["height"] = image_height

    output_polygons = []
    coco_annotation_dict_list = []
    # For every known label colour, mark its pixels in the class mask and
    # (for the classes that are kept) extract polygon annotations.
    for color, class_index in COLOR_TO_CLASS.items():
        # Create a boolean mask for the current color
        color_mask = (mask == color).all(axis=-1)
        color_mask_uint8 = color_mask.astype(np.uint8)

        # Assign the class index to the segmentation mask
        segmentation_mask[color_mask] = class_index

        if class_index not in NEW_CLASS_MAPPING:
            continue
        class_index = NEW_CLASS_MAPPING[class_index]

        # Find contours for the current color mask
        contours, _ = cv2.findContours(color_mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        new_contours = []
        for cnt in contours:
            # Simplify each contour with a tolerance of 0.1% of its perimeter.
            peri = cv2.arcLength(cnt, True)
            approx = cv2.approxPolyDP(cnt, 0.001 * peri, True)
            new_contours.append(approx)

        # Convert contours to polygon coordinates
        polygons = [contour.reshape(-1, 2) for contour in new_contours]

        for polygon in polygons:
            # A polygon needs at least three vertices.
            if polygon.shape[0] < 3:
                continue

            # Convert the polygon to a Shapely Polygon object
            shapely_polygon = Polygon(polygon)
            area = shapely_polygon.area
            rectangle_shapely = shapely_polygon.envelope
            bb_x, bb_y = rectangle_shapely.exterior.xy
            coco_bb = create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2)

            # NOTE(review): class_index has already been remapped through
            # NEW_CLASS_MAPPING (values 0-2), so the `[3, 4]` test never matches
            # and every polygon is filtered with the `area < 100` rule. If doors
            # and windows were meant to use the `area < 1` rule, the raw class
            # index must be compared instead - confirm before changing, because
            # it alters the generated annotations.
            if class_index in [3, 4] and area < 1:
                continue
            if class_index not in [3, 4] and area < 100:
                continue

            coco_seg_poly = []
            poly_sorted = resort_corners(polygon)

            # Flatten the corner list into [x0, y0, x1, y1, ...].
            for p in poly_sorted:
                coco_seg_poly += list(p)

            # Create a dictionary for the COCO annotation
            coco_annotation_dict = {
                "segmentation": [coco_seg_poly],
                "area": area,
                # Fixed: the key was misspelled "iscrow"; COCO's field is "iscrowd".
                "iscrowd": 0,
                "image_id": i,
                "bbox": coco_bb,
                "category_id": class_index,
                "id": instance_count,
            }
            coco_annotation_dict_list.append(coco_annotation_dict)
            instance_count += 1
            output_polygons.append([coco_seg_poly, class_index])

    save_dict["images"].append(img_dict)
    save_dict["annotations"] += coco_annotation_dict_list

    # Print the unique class indices in the segmentation mask to verify
    print(path)
    print(np.unique(segmentation_mask))

    filled_mask = fill_mask(segmentation_mask)

    clean_image = np.array(image)
    # The label mask may differ in size from the image; nearest-neighbour keeps it binary.
    filled_mask_resized = cv2.resize(
        filled_mask, (clean_image.shape[1], clean_image.shape[0]), interpolation=cv2.INTER_NEAREST
    )
    cv2.imwrite(f"{output_aux_dir}/{fn}_fg_mask.png", filled_mask_resized)

    # Keep image pixels where the mask is non-zero and paint everything else white.
    clean_image = clean_image * np.array(filled_mask_resized[:, :, np.newaxis] / 255.0).astype(bool)
    clean_image[filled_mask_resized == 0] = 255
    clean_image = cv2.cvtColor(clean_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(f"{output_image_dir}/{new_fn}.png", clean_image)

    image_with_polygons = draw_polygon_on_image(np.zeros_like(clean_image), output_polygons, CLASS_TO_COLOR)
    cv2.imwrite(f"{output_aux_dir}/{fn}_polylines.png", image_with_polygons)

    output_mappings.append(f"{fn} {new_fn}")

# One "<original name> <new name>" pair per line (plain text despite the .json suffix).
with open(fn_mapping_log, "w") as f:
    for mapping in output_mappings:
        f.write(f"{mapping}\n")

# Serialize save_dict to JSON
json_path = f"{output_annot_dir}/test.json"
with open(json_path, "w") as f:
    json.dump(save_dict, f, default=convert_numpy_to_python)
|
datasets/__init__.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .poly_data import build as build_poly
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def build_dataset(image_set, args):
    """Build the polygon dataset selected by ``args.dataset_name``.

    Args:
        image_set: Split identifier, forwarded unchanged to ``build_poly``.
        args: Configuration object; only ``dataset_name`` is read here, the
            whole object is forwarded to ``build_poly``.

    Returns:
        Whatever ``build_poly(image_set, args)`` returns.

    Raises:
        ValueError: If ``args.dataset_name`` is not one of the supported names.
    """
    supported = ("stru3d", "cubicasa", "waffle", "r2g")
    if args.dataset_name not in supported:
        raise ValueError(f"dataset {args.dataset_name} not supported")

    print(f"Build {args.dataset_name} {image_set} dataset")
    return build_poly(image_set, args)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_dataset_class_labels(dataset_name):
    """Return the semantic label table of a dataset and its inverse.

    Args:
        dataset_name: ``"stru3d"``, ``"cubicasa"`` or ``"r2g"``; any other
            value (including ``"waffle"``) has no label table.

    Returns:
        A ``(semantics_label, id2class)`` pair where ``id2class`` is
        ``semantics_label`` with keys and values swapped, or ``(None, None)``
        for a dataset without a table.

        NOTE: the direction differs per dataset. For ``"cubicasa"`` and
        ``"r2g"`` the first dict maps name -> id and the second id -> name;
        for ``"stru3d"`` the first dict maps id -> name, so the second one
        maps name -> id.
    """
    if dataset_name == "stru3d":
        # Listed in id order: position in the tuple is the class id.
        id_ordered_names = (
            "Living Room",
            "Kitchen",
            "Bedroom",
            "Bathroom",
            "Balcony",
            "Corridor",
            "Dining room",
            "Study",
            "Studio",
            "Store room",
            "Garden",
            "Laundry room",
            "Office",
            "Basement",
            "Garage",
            "Misc.",
            "Door",
            "Window",
        )
        semantics_label = dict(enumerate(id_ordered_names))
    elif dataset_name == "cubicasa":
        id_ordered_names = (
            "Outdoor",
            "Kitchen",
            "Living Room",
            "Bed Room",
            "Bath",
            "Entry",
            "Storage",
            "Garage",
            "Undefined",
            "Window",
            "Door",
        )
        semantics_label = {name: class_id for class_id, name in enumerate(id_ordered_names)}
    elif dataset_name == "r2g":
        id_ordered_names = (
            "unknown",
            "living_room",
            "kitchen",
            "bedroom",
            "bathroom",
            "restroom",
            "balcony",
            "closet",
            "corridor",
            "washing_room",
            "PS",
            "outside",
        )
        semantics_label = {name: class_id for class_id, name in enumerate(id_ordered_names)}
    else:
        return None, None

    swapped = {value: key for key, value in semantics_label.items()}
    return semantics_label, swapped
|
datasets/data_utils.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def compute_centroid(polygon):
    """Return the mean of a polygon's vertices as an ``(x, y)`` tuple.

    This is the plain average of the listed corner coordinates, not the
    area-weighted centroid of the enclosed shape.

    Args:
        polygon: Sequence of ``(x, y)`` points (anything ``np.array`` turns
            into a 2-D array with x in column 0 and y in column 1).
    """
    vertices = np.array(polygon)
    xs = vertices[:, 0]
    ys = vertices[:, 1]
    return (xs.mean(), ys.mean())
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_top_left(polygon):
    """Return the point with the smallest y, ties broken by the smallest x."""

    def row_major_key(point):
        # Compare y first, then x.
        return (point[1], point[0])

    return min(polygon, key=row_major_key)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def sort_polygons(polygons, tolerance=20, reverse=False):
    """Order polygons in reading order (top-to-bottom rows, left-to-right).

    Each polygon is represented by its top-left corner. Polygons are visited
    by increasing corner y and placed into the first existing row whose first
    member's corner y is within ``tolerance``; otherwise they open a new row.
    Rows are then sorted by corner x and concatenated.

    Args:
        polygons: Sequence of polygons, each a sequence of ``(x, y)`` points.
        tolerance: Maximum |dy| to the row's first member for joining that row.
        reverse: If true, the final order is reversed as a whole.

    Returns:
        ``(sorted_polygons, sorted_indices)`` where ``sorted_indices`` are
        positions into the input ``polygons``.
    """
    corners = [get_top_left(poly) for poly in polygons]

    # Stable sort of the original positions by corner y.
    by_height = sorted(range(len(polygons)), key=lambda k: corners[k][1])

    bands = []
    for k in by_height:
        top = corners[k][1]
        for band in bands:
            # band[0] is the polygon that opened the band.
            if abs(corners[band[0]][1] - top) <= tolerance:
                band.append(k)
                break
        else:
            bands.append([k])

    order = []
    for band in bands:
        order.extend(sorted(band, key=lambda k: corners[k][0]))

    if reverse:
        order.reverse()

    return [polygons[k] for k in order], order
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def plot_polygons(polygons, save_path):
    """Draw filled polygons with numbered centre labels and save the figure.

    Args:
        polygons: Iterable of polygons, each convertible to an ``(N, 2)`` array.
        save_path: Destination passed to ``plt.savefig``.

    The figure is closed once it has been saved (or if drawing fails), so
    repeated calls do not accumulate open matplotlib figures.
    """
    fig = plt.figure(figsize=(6, 6))
    try:
        for i, poly in enumerate(polygons):
            poly = np.array(poly)
            plt.fill(poly[:, 0], poly[:, 1], alpha=0.5, label=f"Polygon {i + 1}")
            # Label each polygon at the mean of its vertices.
            centroid = compute_centroid(poly)
            plt.text(centroid[0], centroid[1], f"C{i + 1}", fontsize=10, ha="center")
        plt.gca().set_aspect("equal", adjustable="box")
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
|
datasets/discrete_tokenizer.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class DiscreteTokenizer(object):
    """Pack per-polygon token lists into one fixed-length token sequence.

    Coordinate tokens occupy ``[0, num_bins * num_bins)``. The special tokens
    follow directly after: ``bos``, ``eos``, ``sep``, ``pad`` and, when
    ``add_cls`` is set, ``cls``.

    Args:
        num_bins: Number of quantisation bins per axis.
        seq_len: Length every produced sequence is padded to.
        add_cls: If true, a ``cls`` token is emitted after every polygon.
    """

    def __init__(self, num_bins, seq_len, add_cls=False):
        self.num_bins = num_bins
        vocab_size = num_bins * num_bins
        self.seq_len = seq_len
        self.add_cls = add_cls

        self.bos = vocab_size + 0
        self.eos = vocab_size + 1
        self.sep = vocab_size + 2
        self.pad = vocab_size + 3
        if add_cls:
            self.cls = vocab_size + 4
            self.vocab_size = vocab_size + 5
        else:
            self.vocab_size = vocab_size + 4

    def __len__(self):
        """Return the vocabulary size (coordinate tokens plus special tokens)."""
        return self.vocab_size

    def _padding(self, seq, pad_value, dtype):
        """Return ``seq`` right-padded with ``pad_value`` to ``seq_len`` as a tensor.

        The input is copied first, so the caller's list is left untouched.
        Sequences already longer than ``seq_len`` are returned unpadded and
        untruncated.
        """
        padded = list(seq)
        missing = self.seq_len - len(padded)
        if missing > 0:
            padded.extend([pad_value] * missing)
        return torch.tensor(np.array(padded), dtype=dtype)

    def __call__(self, seq, add_bos, add_eos, dtype, return_indices=False):
        """Tokenise a list of polygons into one padded sequence.

        Args:
            seq: List of polygons, each a list of coordinate tokens.
            add_bos: Prepend the ``bos`` token.
            add_eos: Overwrite the LAST slot of the padded sequence with
                ``eos`` (the token is not placed directly after the content).
            dtype: Torch dtype of the returned tensor.
            return_indices: Also return the positions (into ``seq``) of the
                polygons that were actually included.

        Returns:
            A tensor of length ``seq_len``, or ``(tensor, indices)`` when
            ``return_indices`` is true.
        """
        out = [self.bos] if add_bos else []
        num_extra = 1 if not self.add_cls else 2  # cls and sep
        indices = []
        for i, sub in enumerate(seq):
            # A polygon that would overflow seq_len (including its trailing
            # cls/sep tokens) is skipped; later, shorter ones may still fit.
            if len(out) + len(sub) + num_extra > self.seq_len:
                continue
            out.extend(sub)
            indices.append(i)
            if self.add_cls:
                out.append(self.cls)  # cls token
            out.append(self.sep)

        # The final polygon is not followed by a separator.
        if out and out[-1] == self.sep:
            out.pop()

        if self.seq_len > len(out):
            out.extend([self.pad] * (self.seq_len - len(out)))

        if add_eos:
            out[-1] = self.eos

        tokens = torch.tensor(out, dtype=dtype)
        if return_indices:
            return tokens, indices
        return tokens
|
datasets/poly_data.py
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
import torchvision
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from pycocotools.coco import COCO
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
|
| 14 |
+
from datasets.data_utils import sort_polygons
|
| 15 |
+
from datasets.discrete_tokenizer import DiscreteTokenizer
|
| 16 |
+
from datasets.transforms import ResizeAndPad
|
| 17 |
+
from detectron2.data import transforms as T
|
| 18 |
+
from detectron2.data.detection_utils import annotations_to_instances, transform_instance_annotations
|
| 19 |
+
from detectron2.structures import BoxMode
|
| 20 |
+
from util.poly_ops import resort_corners
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TokenType(Enum):
    """Per-position token kind used as the token label of a target sequence.

    0 for <coord>, 1 for <sep>, 2 for <eos>, 3 for <cls>.
    """

    coord = 0  # a polygon corner
    sep = 1  # separator written after a polygon
    eos = 2  # end of the sequence
    cls = 3  # per-polygon class token
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Category ids treated as window/door classes, keyed by dataset name.
# An empty list means the dataset has no such categories.
WD_INDEX = {
    "stru3d": [16, 17],
    "cubicasa": [9, 10],
    "waffle": [],
    "r2g": [],
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MultiPoly(Dataset):
    """Floor-plan polygon dataset backed by a COCO annotation file.

    Every item is the record produced by ``ConvertToCocoDictWithOrder_plus``
    for one image and its annotations.

    Args:
        img_folder: Directory that image file names are resolved against.
        ann_file: Path of the COCO annotation JSON.
        transforms: Augmentations handed to the converter (may be ``None``).
        semantic_classes: Number of semantic classes; ``-1`` additionally
            drops window/door annotations for ``stru3d`` and ``cubicasa``.
        dataset_name: Dataset identifier used for the filtering above.
        image_norm: Forwarded to the converter.
        poly2seq: Whether polygons are also converted to token sequences.
        converter_version: ``"v3_flipped"`` selects right-to-left ordering,
            anything else left-to-right.
        random_drop_rate: Forwarded to the converter.
        **kwargs: Forwarded to the converter.
    """

    def __init__(
        self,
        img_folder,
        ann_file,
        transforms,
        semantic_classes,
        dataset_name="",
        image_norm=False,
        poly2seq=False,
        converter_version="v1",
        random_drop_rate=0.0,
        **kwargs,
    ):
        super().__init__()

        self.root = img_folder
        self._transforms = transforms
        self.semantic_classes = semantic_classes
        self.dataset_name = dataset_name

        self.coco = COCO(ann_file)
        self.ids = sorted(self.coco.imgs.keys())

        self.poly2seq = poly2seq
        order_type = "r2l" if converter_version == "v3_flipped" else "l2r"
        # NOTE(review): dataset_name is a named parameter of this constructor,
        # so it is not part of **kwargs and is never forwarded; the converter
        # therefore keeps its own default dataset_name. Confirm that is intended.
        self.prepare = ConvertToCocoDictWithOrder_plus(
            self.root,
            self._transforms,
            image_norm,
            poly2seq,
            semantic_classes=semantic_classes,
            order_type=order_type,
            random_drop_rate=random_drop_rate,
            **kwargs,
        )

    def get_image(self, path):
        """Open ``path`` relative to the image folder and return the PIL image."""
        return Image.open(os.path.join(self.root, path))

    def get_vocab_size(self):
        """Return the tokenizer's vocabulary size, or ``None`` without poly2seq."""
        if not self.poly2seq:
            return None
        return len(self.prepare.tokenizer)

    def get_tokenizer(self):
        """Return the converter's tokenizer, or ``None`` without poly2seq."""
        if not self.poly2seq:
            return None
        return self.prepare.tokenizer

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            dict: COCO format dict
        """
        img_id = self.ids[index]
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

        # Without semantic classes, window/door annotations are removed for the
        # datasets whose window/door category ids are known.
        if self.semantic_classes == -1 and self.dataset_name in ("stru3d", "cubicasa"):
            excluded = WD_INDEX[self.dataset_name]
            annotations = [ann for ann in annotations if ann["category_id"] not in excluded]

        file_name = self.coco.loadImgs(img_id)[0]["file_name"]
        return self.prepare(img_id, file_name, annotations)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class MultiPolyWD(MultiPoly):
    """Variant of ``MultiPoly`` that keeps only window/door annotations.

    For ``stru3d``, ``rplan`` and ``cubicasa`` the annotations are restricted
    to a fixed pair of category ids; for any other dataset name they are
    passed through unfiltered.
    """

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            dict: COCO format dict
        """
        coco = self.coco
        img_id = self.ids[index]
        target = coco.loadAnns(coco.getAnnIds(imgIds=img_id))

        # Category ids to keep for the current dataset (None = keep everything).
        if self.dataset_name == "stru3d":
            kept_ids = [16, 17]
        elif self.dataset_name == "rplan":
            kept_ids = [9, 11]
        elif self.dataset_name == "cubicasa":
            kept_ids = [9, 10]
        else:
            kept_ids = None

        if kept_ids is not None:
            target = [ann for ann in target if ann["category_id"] in kept_ids]

        file_name = coco.loadImgs(img_id)[0]["file_name"]
        return self.prepare(img_id, file_name, target)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class ConvertToCocoDict(object):
    """Turn one COCO image entry and its annotations into a model-ready record.

    Args:
        root: Directory that image paths are resolved against.
        augmentations: Callable applied to a detectron2 ``AugInput`` that
            returns the applied transforms, or ``None`` for no augmentation.
        image_norm: If true, the image tensor is normalised with the fixed
            mean/std configured below.
        poly2seq: If true, polygons are additionally converted to token
            sequences and a ``DiscreteTokenizer`` is built from ``**kwargs``.
        semantic_classes: Number of semantic classes; ``semantic_classes - 1``
            is used as the dummy/undefined class id.
        add_cls_token: Whether every polygon is followed by a cls token.
        per_token_class: Whether class labels are emitted per token instead of
            per polygon.
        mask_format: Forwarded to ``annotations_to_instances``.
        **kwargs: Forwarded to ``DiscreteTokenizer`` when ``poly2seq`` is set.
    """

    def __init__(
        self,
        root,
        augmentations,
        image_norm,
        poly2seq=False,
        semantic_classes=-1,
        add_cls_token=False,
        per_token_class=False,
        mask_format="polygon",
        **kwargs,
    ):
        self.root = root
        self.augmentations = augmentations
        if image_norm:
            self.image_normalize = torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
        else:
            self.image_normalize = None

        self.semantic_classes = semantic_classes
        self.poly2seq = poly2seq
        if poly2seq:
            self.tokenizer = DiscreteTokenizer(add_cls=add_cls_token, **kwargs)
        self.add_cls_token = add_cls_token
        self.per_token_class = per_token_class
        self.mask_format = mask_format

    def _expand_image_dims(self, x):
        """Return ``x`` in channel-first layout: (h,w) -> (1,h,w), (h,w,c) -> (c,h,w)."""
        if len(x.shape) == 2:
            exp_img = np.expand_dims(x, 0)
        else:
            exp_img = x.transpose((2, 0, 1))  # (h,w,c) -> (c,h,w)
        return exp_img

    def __call__(self, img_id, path, target):
        """Load the image and build the record for one sample.

        Args:
            img_id: Image id stored in the record.
            path: Image path relative to ``root``.
            target: List of COCO annotation dicts for the image. The dicts are
                copied (shallowly) and never modified.

        Returns:
            dict with ``file_name``, ``height``, ``width``, ``image_id``,
            ``image`` (float tensor scaled by 1/255, channel-first) and
            ``instances``; ``annotations`` is kept only when no augmentation
            is applied. With ``poly2seq`` the entries returned by
            ``_get_bilinear_interpolation_coeffs`` are merged in.
        """
        file_name = os.path.join(self.root, path)

        # Read the pixels inside a context manager so the file handle is released.
        with Image.open(file_name) as pil_image:
            img = np.array(pil_image)

        if img.ndim >= 3 and img.shape[-1] > 3:  # drop alpha channel
            img = img[:, :, :3]
        # Fixed: numpy image arrays are (height, width[, channels]); the
        # previous `w, h = ...` unpacking swapped the two for non-square images.
        h, w = img.shape[:2]

        record = {}
        record["file_name"] = file_name
        record["height"] = h
        record["width"] = w
        record["image_id"] = img_id

        # Work on shallow copies: the incoming dicts may be shared with the
        # caller (e.g. owned by the COCO index), and
        # transform_instance_annotations below rewrites "bbox", "bbox_mode" and
        # "segmentation" on the dict it is given. Mutating the shared dicts
        # would make augmentations accumulate across repeated accesses.
        target = [dict(obj, bbox_mode=BoxMode.XYWH_ABS) for obj in target]

        record["annotations"] = target

        if self.augmentations is None:
            record["image"] = (1 / 255) * torch.as_tensor(np.ascontiguousarray(self._expand_image_dims(img)))
            record["instances"] = annotations_to_instances(target, (h, w), mask_format=self.mask_format)
        else:
            aug_input = T.AugInput(img)
            transforms = self.augmentations(aug_input)
            image = aug_input.image
            record["image"] = (1 / 255) * torch.as_tensor(np.array(self._expand_image_dims(image)))
            h, w = image.shape[:2]  # update size

            annos = [
                transform_instance_annotations(obj, transforms, image.shape[:2])
                for obj in record.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            # resort corners after augmentation: so that all corners start from upper-left counterclockwise
            for anno in annos:
                anno["segmentation"][0] = resort_corners(anno["segmentation"][0])

            record["instances"] = annotations_to_instances(annos, (h, w), mask_format=self.mask_format)

        if self.image_normalize is not None:
            record["image"] = self.image_normalize(record["image"])

        # convert polygons to sequences
        if self.poly2seq:
            if not hasattr(record["instances"], "gt_masks"):
                # No annotations survived (only happens for the window/door-only
                # setup): emit a single dummy point with the dummy class.
                polygons = [np.array([[0.0, 0.0]])]
                polygons_label = [self.semantic_classes - 1]  # dummy class
            else:
                # Both x and y are divided by (w - 1) and clipped to [0, 1];
                # this presumes square images.
                polygons = [
                    np.clip(np.array(inst).reshape(-1, 2) / (w - 1), 0, 1)
                    for inst in record["instances"].gt_masks.polygons
                ]
                polygons_label = [inst.item() for inst in record["instances"].gt_classes]
            record.update(
                self._get_bilinear_interpolation_coeffs(
                    polygons, polygons_label, self.add_cls_token, self.per_token_class
                )
            )

        return record

    def _fractional_offsets(self, quant_polys, axis, add_cls_token):
        """Return the padded per-token fractional part of one coordinate axis.

        The layout is: one 0 for the bos token, then for each polygon the
        fractional part of every point on ``axis`` followed by a 0 for the cls
        token (if used) and a 0 for the separator. The trailing separator slot
        is dropped before padding to ``seq_len`` with zeros.
        """
        offsets = [0]  # [0] for bos token
        for polygon in quant_polys:
            offsets.extend(point[axis] - math.floor(point[axis]) for point in polygon)
            if add_cls_token:
                offsets.append(0)  # for cls token
            offsets.append(0)  # for separator token
        offsets = offsets[:-1]  # there is no separator token in the end
        return self.tokenizer._padding(offsets, 0, dtype=torch.float32)

    def _get_bilinear_interpolation_coeffs(self, polygons, polygons_label, add_cls_token=False, per_token_class=False):
        """Build token sequences, interpolation weights and targets for polygons.

        Every normalised point is scaled to the ``num_bins`` grid. The four
        surrounding grid cells give the token sequences ``seq11``/``seq21``/
        ``seq12``/``seq22`` (floor/ceil of x and y), and the fractional parts
        give the weights ``delta_x1``/``delta_y1`` with ``delta_*2 = 1 - delta_*1``.

        Args:
            polygons: List of ``(N, 2)`` arrays with coordinates in [0, 1].
            polygons_label: Class id per polygon.
            add_cls_token: Reserve a cls slot after every polygon.
            per_token_class: Emit one class label per token instead of one per
                polygon.

        Returns:
            dict with ``delta_x1``, ``delta_x2``, ``delta_y1``, ``delta_y2``,
            ``seq11``, ``seq21``, ``seq12``, ``seq22``, ``target_seq``,
            ``token_labels``, ``mask`` and ``target_polygon_labels``.
        """
        num_bins = self.tokenizer.num_bins
        seq_len = self.tokenizer.seq_len
        quant_poly = [poly * (num_bins - 1) for poly in polygons]

        def corner_tokens(round_x, round_y):
            # Token id of the grid cell selected by the two rounding functions.
            return [[round_x(p[0]) * num_bins + round_y(p[1]) for p in poly] for poly in quant_poly]

        index11 = corner_tokens(math.floor, math.floor)
        index21 = corner_tokens(math.ceil, math.floor)
        index12 = corner_tokens(math.floor, math.ceil)
        index22 = corner_tokens(math.ceil, math.ceil)

        seq11 = self.tokenizer(index11, add_bos=True, add_eos=False, dtype=torch.long)
        seq21 = self.tokenizer(index21, add_bos=True, add_eos=False, dtype=torch.long)
        seq12 = self.tokenizer(index12, add_bos=True, add_eos=False, dtype=torch.long)
        seq22 = self.tokenizer(index22, add_bos=True, add_eos=False, dtype=torch.long)

        # Real-valued targets and the per-position token kinds.
        target_seq = []
        token_labels = []  # 0 for <coord>, 1 for <sep>, 2 for <eos>, 3 for <cls>
        num_extra = 1 if not add_cls_token else 2  # cls and sep
        count_polys = 0
        for poly in polygons:
            # NOTE(review): this loop stops at the first polygon that does not
            # fit, whereas DiscreteTokenizer.__call__ skips such a polygon and
            # keeps trying later ones (and also counts the bos token). The
            # input sequences and these targets can therefore disagree near the
            # length limit - confirm which behaviour is intended.
            if len(token_labels) + len(poly) + num_extra > seq_len:
                break
            token_labels.extend([TokenType.coord.value] * len(poly))
            if add_cls_token:
                token_labels.append(TokenType.cls.value)  # cls token
            token_labels.append(TokenType.sep.value)  # separator token
            target_seq.extend(poly)
            if add_cls_token:
                target_seq.append([0, 0])  # padding for cls token
            target_seq.append([0, 0])  # padding for sep/end token
            count_polys += 1

        # The separator after the last polygon becomes the end-of-sequence token.
        if len(token_labels) > 0:
            token_labels[-1] = TokenType.eos.value

        # True for real positions, False for padding.
        mask = torch.ones(seq_len, dtype=torch.bool)
        if len(token_labels) < seq_len:
            mask[len(token_labels) :] = 0
        target_seq = self.tokenizer._padding(target_seq, [0, 0], dtype=torch.float32)
        token_labels = self.tokenizer._padding(token_labels, -1, dtype=torch.long)

        used_polys = quant_poly[:count_polys]
        delta_x1 = self._fractional_offsets(used_polys, 0, add_cls_token)
        delta_x2 = 1 - delta_x1
        delta_y1 = self._fractional_offsets(used_polys, 1, add_cls_token)
        delta_y2 = 1 - delta_y1

        if not per_token_class:
            target_polygon_labels = polygons_label[:count_polys]
        else:
            target_polygon_labels = []
            for poly, poly_label in zip(used_polys, polygons_label[:count_polys]):
                target_polygon_labels.extend([poly_label] * len(poly))
                target_polygon_labels.append(self.semantic_classes - 1)  # undefined class for <sep> and <eos> token

        # Fixed: pad based on the list that is actually padded. The previous
        # check used len(polygons_label), which skipped padding when there were
        # at least seq_len input polygons but fewer labels were kept.
        max_label_length = seq_len
        if len(target_polygon_labels) < max_label_length:
            target_polygon_labels.extend([-1] * (max_label_length - len(target_polygon_labels)))

        target_polygon_labels = torch.tensor(target_polygon_labels, dtype=torch.long)

        return {
            "delta_x1": delta_x1,
            "delta_x2": delta_x2,
            "delta_y1": delta_y1,
            "delta_y2": delta_y2,
            "seq11": seq11,
            "seq21": seq21,
            "seq12": seq12,
            "seq22": seq22,
            "target_seq": target_seq,
            "token_labels": token_labels,
            "mask": mask,
            "target_polygon_labels": target_polygon_labels,
        }
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class ConvertToCocoDictWithOrder_plus(ConvertToCocoDict):
    """Converter that re-orders polygons before turning them into token sequences.

    Polygons whose label is NOT in ``WD_INDEX[dataset_name]`` ("room" group) are
    placed first, followed by polygons whose label IS in that set ("wd" group).
    Each group is ordered with ``sort_polygons``; ``order_type == "r2l"`` reverses
    that ordering. Optionally a random subset of polygons is dropped.
    """

    def __init__(
        self,
        root,
        augmentations,
        image_norm,
        poly2seq=False,
        semantic_classes=-1,
        add_cls_token=False,
        per_token_class=False,
        mask_format="polygon",
        dataset_name="stru3d",
        order_type="l2r",
        random_drop_rate=0.0,
        **kwargs,
    ):
        """Forward the shared options to the base converter and store ordering options.

        Args:
            dataset_name: Key used to look up the "wd" label set in ``WD_INDEX``.
            order_type: ``"l2r"`` or ``"r2l"``; ``"r2l"`` reverses the polygon sort.
            random_drop_rate: Per-polygon probability of being dropped (0 disables).
            **kwargs: Passed both to the base class and to ``DiscreteTokenizer``.
        """
        super().__init__(
            root,
            augmentations,
            image_norm,
            poly2seq,
            semantic_classes,
            add_cls_token,
            per_token_class,
            mask_format,
            **kwargs,
        )
        self.dataset_name = dataset_name
        self.order_type = order_type  # l2r, r2l
        self.random_drop_rate = random_drop_rate
        self.tokenizer = DiscreteTokenizer(add_cls=add_cls_token, **kwargs)

    def _fractional_offsets(self, quant_poly, poly_indices, axis, add_cls_token):
        """Build the padded per-token fractional part of one coordinate axis.

        The list starts with a 0 for the <bos> slot; after every polygon a 0 is
        added for the optional <cls> slot and another 0 for the separator slot.
        The trailing separator slot is removed before padding.
        """
        offsets = [0]  # [0] for bos token
        for i in poly_indices:
            offsets.extend(point[axis] - math.floor(point[axis]) for point in quant_poly[i])
            if add_cls_token:
                offsets.append(0)  # for cls token
            offsets.append(0)  # for separator token
        offsets = offsets[:-1]  # there is no separator token in the end
        return self.tokenizer._padding(offsets, 0, dtype=torch.float32)

    def _get_bilinear_interpolation_coeffs(self, polygons, polygons_label, add_cls_token=False, per_token_class=False):
        """Order, optionally sub-sample, quantize and tokenize the polygons.

        Args:
            polygons: Sequence of polygons; each is multiplied by ``num_bins - 1``,
                so array-like polygons are expected — presumably with coordinates
                normalised to [0, 1] (TODO confirm against the caller).
            polygons_label: Label of each polygon, aligned with ``polygons``.
            add_cls_token: Reserve one extra slot per polygon for a <cls> token.
            per_token_class: Emit one semantic label per token instead of one per polygon.

        Returns:
            dict with the four corner token sequences (``seq11`` .. ``seq22``), the
            fractional offsets (``delta_x1/x2/y1/y2``), ``target_seq``,
            ``token_labels``, ``mask`` and the target / input polygon labels.
        """
        num_bins = self.tokenizer.num_bins
        wd_labels = WD_INDEX[self.dataset_name]
        reverse = self.order_type == "r2l"

        room_indices = [idx for idx, label in enumerate(polygons_label) if label not in wd_labels]
        wd_indices = [idx for idx, label in enumerate(polygons_label) if label in wd_labels]

        # Sort each group on its own, then map the sorted positions back to the
        # indices of the original ``polygons`` list.
        _, room_sorted_indices = sort_polygons([polygons[idx] for idx in room_indices], reverse=reverse)
        _, wd_sorted_indices = sort_polygons([polygons[idx] for idx in wd_indices], reverse=reverse)
        room_indices = [room_indices[_idx] for _idx in room_sorted_indices]
        wd_indices = [wd_indices[_idx] for _idx in wd_sorted_indices]

        combined_indices = room_indices + wd_indices  # room first
        if self.random_drop_rate > 0 and len(combined_indices) > 2:
            keep_indices = np.where(np.random.rand(len(combined_indices)) >= self.random_drop_rate)[0].tolist()
            if len(keep_indices) > 0:  # Only apply drop if we have something left
                combined_indices = [combined_indices[i] for i in keep_indices]

        polygons = [polygons[i] for i in combined_indices]
        polygons_label = [polygons_label[i] for i in combined_indices]

        # Scale to bin space and compute the flat index of the four surrounding
        # grid corners of every point (floor/ceil on each axis).
        quant_poly = [poly * (num_bins - 1) for poly in polygons]
        index11 = [[math.floor(p[0]) * num_bins + math.floor(p[1]) for p in poly] for poly in quant_poly]
        index21 = [[math.ceil(p[0]) * num_bins + math.floor(p[1]) for p in poly] for poly in quant_poly]
        index12 = [[math.floor(p[0]) * num_bins + math.ceil(p[1]) for p in poly] for poly in quant_poly]
        index22 = [[math.ceil(p[0]) * num_bins + math.ceil(p[1]) for p in poly] for poly in quant_poly]

        seq11 = self.tokenizer(index11, add_bos=True, add_eos=False, dtype=torch.long)
        seq21 = self.tokenizer(index21, add_bos=True, add_eos=False, dtype=torch.long)
        seq12 = self.tokenizer(index12, add_bos=True, add_eos=False, dtype=torch.long)
        # NOTE(review): ``poly_indices`` presumably lists the polygons the tokenizer
        # actually kept in the sequence — confirm against DiscreteTokenizer.
        seq22, poly_indices = self.tokenizer(
            index22, add_bos=True, add_eos=False, dtype=torch.long, return_indices=True
        )

        # in real values instead
        target_seq = []
        token_labels = []  # 0 for <coord>, 1 for <sep>, 2 for <eos>, 3 for <cls>

        for i in poly_indices:
            token_labels.extend([TokenType.coord.value] * len(polygons[i]))
            if add_cls_token:
                token_labels.append(TokenType.cls.value)  # cls token
            token_labels.append(TokenType.sep.value)  # separator token
            target_seq.extend(polygons[i])
            if add_cls_token:
                target_seq.append([0, 0])  # padding for cls token
            target_seq.append([0, 0])  # padding for sep/end token
        # the last separator becomes the end-of-sequence token
        token_labels[-1] = TokenType.eos.value

        mask = torch.ones(self.tokenizer.seq_len, dtype=torch.bool)
        if len(token_labels) < self.tokenizer.seq_len:
            mask[len(token_labels) :] = 0
        target_seq = self.tokenizer._padding(target_seq, [0, 0], dtype=torch.float32)
        token_labels = self.tokenizer._padding(token_labels, -1, dtype=torch.long)

        delta_x1 = self._fractional_offsets(quant_poly, poly_indices, 0, add_cls_token)
        delta_x2 = 1 - delta_x1
        delta_y1 = self._fractional_offsets(quant_poly, poly_indices, 1, add_cls_token)
        delta_y2 = 1 - delta_y1

        undefined_class = self.semantic_classes - 1  # undefined class for <sep> and <eos> token
        if not per_token_class:
            target_polygon_labels = [polygons_label[i] for i in poly_indices]
            input_polygon_labels = torch.tensor(list(target_polygon_labels), dtype=torch.long)
        else:
            # NOTE(review): with add_cls_token the token stream has an extra <cls>
            # slot per polygon that this per-token list does not contain — confirm
            # the alignment is intended.
            target_polygon_labels = []
            for i in poly_indices:
                target_polygon_labels.extend([polygons_label[i]] * len(quant_poly[i]))
                target_polygon_labels.append(undefined_class)
            input_polygon_labels = torch.tensor(
                [undefined_class] + target_polygon_labels[:-1], dtype=torch.long
            )  # right shift by one: <bos>, ..., <coord>

        # Pad with -1 up to the sequence length. The guard is on the label list
        # itself (not on the number of polygons), so a short label list is always
        # padded regardless of how many polygons were supplied.
        missing = self.tokenizer.seq_len - len(target_polygon_labels)
        if missing > 0:
            target_polygon_labels.extend([-1] * missing)

        target_polygon_labels = torch.tensor(target_polygon_labels, dtype=torch.long)

        return {
            "delta_x1": delta_x1,
            "delta_x2": delta_x2,
            "delta_y1": delta_y1,
            "delta_y2": delta_y2,
            "seq11": seq11,
            "seq21": seq21,
            "seq12": seq12,
            "seq22": seq22,
            "target_seq": target_seq,
            "token_labels": token_labels,
            "mask": mask,
            "target_polygon_labels": target_polygon_labels,
            "input_polygon_labels": input_polygon_labels,
        }
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def make_poly_transforms(dataset_name, image_set, image_size=256, disable_image_transform=False):
    """Assemble the augmentation pipeline for one dataset split.

    A ``ResizeAndPad`` step (white padding) is used for the "cubicasa" and
    "waffle" datasets, and for "r2g" unless ``image_size`` is 512.

    Args:
        dataset_name: Name of the dataset.
        image_set: ``"train"``, ``"val"`` or ``"test"``.
        image_size: Side length of the square target image.
        disable_image_transform: Skip the random flips/rotation on the train split.

    Returns:
        A ``T.AugmentationList`` for "train"; for "val"/"test" the list, or
        ``None`` when no step applies.

    Raises:
        ValueError: If ``image_set`` is not one of the known splits.
    """
    needs_resize = dataset_name in ("cubicasa", "waffle") or (dataset_name == "r2g" and image_size != 512)
    pipeline = [ResizeAndPad((image_size, image_size), pad_value=255)] if needs_resize else []

    if image_set == "train":
        if not disable_image_transform:
            pipeline += [
                T.RandomFlip(prob=0.5, horizontal=True, vertical=False),
                T.RandomFlip(prob=0.5, horizontal=False, vertical=True),
                T.RandomRotation([0.0, 90.0, 180.0, 270.0], expand=False, center=None, sample_style="choice"),
            ]
        return T.AugmentationList(pipeline)

    if image_set in ("val", "test"):
        return T.AugmentationList(pipeline) if pipeline else None

    raise ValueError(f"unknown {image_set}")
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def build(image_set, args):
    """Create the dataset object for one split from the parsed arguments.

    Args:
        image_set: ``"train"``, ``"val"`` or ``"test"``; selects ``<root>/<split>``
            and ``<root>/annotations/<split>.json``.
        args: Namespace holding the dataset options read below.

    Returns:
        A ``MultiPolyWD`` when ``args.wd_only`` is truthy, otherwise a ``MultiPoly``.
    """
    root = Path(args.dataset_root)
    assert root.exists(), f"provided data path {root} does not exist"

    annotation_dir = root / "annotations"
    split_paths = {
        split: (root / split, annotation_dir / f"{split}.json") for split in ("train", "val", "test")
    }
    img_folder, ann_file = split_paths[image_set]

    image_transform = make_poly_transforms(
        args.dataset_name,
        image_set,
        image_size=args.image_size,
        disable_image_transform=getattr(args, "disable_image_transform", False),
    )

    wd_only = args.wd_only
    # Options accepted by both dataset classes.
    shared_options = dict(
        transforms=image_transform,
        semantic_classes=args.semantic_classes,
        dataset_name=args.dataset_name,
        image_norm=args.image_norm,
        poly2seq=args.poly2seq,
        num_bins=args.num_bins,
        seq_len=args.seq_len,
        add_cls_token=args.add_cls_token,
        per_token_class=args.per_token_sem_loss,
        mask_format=getattr(args, "mask_format", "polygon"),
    )

    if wd_only:
        return MultiPolyWD(img_folder, ann_file, **shared_options)

    return MultiPoly(
        img_folder,
        ann_file,
        converter_version=getattr(args, "converter_version", "v1"),
        random_drop_rate=getattr(args, "random_drop_rate", 0.0),
        **shared_options,
    )
|
datasets/room_dropout.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from skimage.draw import polygon
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RoomDropoutStrategy:
    """
    Strategy for randomly dropping rooms from a density map using ground truth coordinates.

    Density map: grayscale image where foreground (rooms) are white points and background is black
    GT room coordinates: list of 2D points defining each room's boundary
    """

    def __init__(self, density_map: np.ndarray, room_coordinates: List[List[Tuple[int, int]]]):
        """
        Initialize the dropout strategy.

        Args:
            density_map: Grayscale image (H, W) where white pixels represent rooms
            room_coordinates: List of rooms, each room is a list of (x, y) coordinate tuples
        """
        # Keep a private copy so the caller's array is never modified.
        self.original_density_map = density_map.copy()
        self.room_coordinates = room_coordinates
        self.num_rooms = len(room_coordinates)

    def _room_mask(self, room_coords) -> np.ndarray:
        """Rasterize one room outline into a binary (H, W) uint8 mask.

        Rooms with fewer than 3 points cannot form a polygon and yield an
        all-zero mask.
        """
        h, w = self.original_density_map.shape
        mask = np.zeros((h, w), dtype=np.uint8)

        if len(room_coords) >= 3:  # Need at least 3 points for a polygon
            coords = np.array(room_coords)
            # skimage expects (row, col) order, i.e. (y, x)
            rr, cc = polygon(coords[:, 1], coords[:, 0], shape=(h, w))
            mask[rr, cc] = 1

        return mask

    def create_room_masks(self) -> List[np.ndarray]:
        """
        Create binary masks for each room using their GT coordinates.

        Returns:
            List of binary masks, one for each room
        """
        return [self._room_mask(room_coords) for room_coords in self.room_coordinates]

    def drop_rooms_random(self, dropout_rate: float = 0.3, seed: Optional[int] = None) -> Tuple[np.ndarray, List[int]]:
        """
        Randomly drop rooms from the density map.

        When ``seed`` is given, a private ``random.Random(seed)`` generator is used,
        so the process-wide ``random`` / ``np.random`` state is left untouched. It
        selects the same rooms as seeding the global generator with that value would.

        Args:
            dropout_rate: Fraction of rooms to drop (0.0 to 1.0)
            seed: Random seed for reproducibility

        Returns:
            Tuple of (modified_density_map, list_of_dropped_room_indices)
        """
        rng = random if seed is None else random.Random(seed)

        # Determine number of rooms to drop (rounded down)
        num_to_drop = int(self.num_rooms * dropout_rate)

        # Randomly select room indices to drop
        dropped_indices = rng.sample(list(range(self.num_rooms)), num_to_drop)

        return self._apply_dropout(dropped_indices), dropped_indices

    def drop_rooms_by_indices(self, room_indices: List[int]) -> np.ndarray:
        """
        Drop specific rooms by their indices.

        Args:
            room_indices: List of room indices to drop

        Returns:
            Modified density map with specified rooms removed
        """
        return self._apply_dropout(room_indices)

    def drop_rooms_by_area(
        self, min_area: Optional[int] = None, max_area: Optional[int] = None
    ) -> Tuple[np.ndarray, List[int]]:
        """
        Drop rooms based on their area constraints.

        Args:
            min_area: Minimum area threshold (drop rooms smaller than this)
            max_area: Maximum area threshold (drop rooms larger than this)

        Returns:
            Tuple of (modified_density_map, list_of_dropped_room_indices)
        """
        dropped_indices = []

        for i, mask in enumerate(self.create_room_masks()):
            area = np.sum(mask)  # pixel count of the rasterized room
            too_small = min_area is not None and area < min_area
            too_large = max_area is not None and area > max_area
            if too_small or too_large:
                dropped_indices.append(i)

        return self._apply_dropout(dropped_indices), dropped_indices

    def _apply_dropout(self, room_indices_to_drop: List[int]) -> np.ndarray:
        """
        Apply dropout by removing specified rooms from the density map.

        Only the rooms that are actually dropped are rasterized. Indices outside
        ``[0, number of rooms)`` are ignored.

        Args:
            room_indices_to_drop: List of room indices to remove

        Returns:
            Modified density map with rooms removed
        """
        modified_map = self.original_density_map.copy()

        for room_idx in room_indices_to_drop:
            if 0 <= room_idx < len(self.room_coordinates):
                mask = self._room_mask(self.room_coordinates[room_idx])
                # Set pixels in the room area to background (black/0)
                modified_map[mask == 1] = 0

        return modified_map

    def visualize_dropout(
        self, original_map: np.ndarray, modified_map: np.ndarray, dropped_indices: List[int]
    ) -> np.ndarray:
        """
        Create a visualization showing the dropout effect.

        Args:
            original_map: Original density map
            modified_map: Modified density map after dropout
            dropped_indices: Indices of dropped rooms

        Returns:
            Visualization image with original and modified maps side by side
        """
        h, w = original_map.shape

        # Create side-by-side comparison
        vis = np.zeros((h, w * 2), dtype=np.uint8)
        vis[:, :w] = original_map
        vis[:, w:] = modified_map
        vis_color = cv2.cvtColor(vis, cv2.COLOR_GRAY2BGR)

        # Highlight dropped rooms in red on the left (original) half. The room
        # masks are (h, w), so they must index that half of the canvas rather
        # than the full (h, 2w) image.
        left_half = vis_color[:, :w]
        for idx in dropped_indices:
            if 0 <= idx < len(self.room_coordinates):
                mask = self._room_mask(self.room_coordinates[idx])
                left_half[mask == 1] = (0, 0, 255)  # BGR: pure red

        return vis_color
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# Example usage and testing
|
| 184 |
+
def example_usage():
    """
    Example of how to use the RoomDropoutStrategy class.

    Builds a synthetic 200x200 density map with five rooms and runs the three
    dropout modes on it.

    Returns:
        Tuple of (density_map, random_dropout_map, index_dropout_map, area_dropout_map)
    """
    # Create a sample density map (200x200 image)
    density_map = np.zeros((200, 200), dtype=np.uint8)

    # Create some sample room coordinates (rectangles and polygons)
    room_coordinates = [
        # Room 1: Rectangle
        [(20, 20), (80, 20), (80, 60), (20, 60)],
        # Room 2: Another rectangle
        [(100, 30), (180, 30), (180, 80), (100, 80)],
        # Room 3: L-shaped room
        [(30, 100), (90, 100), (90, 130), (60, 130), (60, 160), (30, 160)],
        # Room 4: Triangle
        [(120, 120), (160, 120), (140, 160)],
        # Room 5: Pentagon
        [(50, 180), (70, 170), (90, 180), (80, 195), (40, 195)],
    ]

    # Fill the density map with white pixels for each room. ``polygon`` is the
    # module-level skimage import; it takes (row, col) order, i.e. (y, x).
    for room_coords in room_coordinates:
        coords = np.array(room_coords)
        rr, cc = polygon(coords[:, 1], coords[:, 0], shape=density_map.shape)
        density_map[rr, cc] = 255  # White pixels for rooms

    # Initialize the dropout strategy
    dropout_strategy = RoomDropoutStrategy(density_map, room_coordinates)

    # Example 1: Random dropout
    print("Example 1: Random dropout (30% of rooms)")
    modified_map1, dropped_indices1 = dropout_strategy.drop_rooms_random(dropout_rate=0.3, seed=42)
    print(f"Dropped rooms: {dropped_indices1}")

    # Example 2: Drop specific rooms
    print("\nExample 2: Drop specific rooms (indices 0 and 2)")
    modified_map2 = dropout_strategy.drop_rooms_by_indices([0, 2])

    # Example 3: Drop rooms by area
    print("\nExample 3: Drop rooms with area > 3000 pixels")
    modified_map3, dropped_indices3 = dropout_strategy.drop_rooms_by_area(max_area=3000)
    print(f"Dropped rooms by area: {dropped_indices3}")

    return density_map, modified_map1, modified_map2, modified_map3
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
if __name__ == "__main__":
|
| 237 |
+
example_usage()
|
datasets/transforms.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
|
| 3 |
+
from detectron2.data import transforms as T
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Resize(T.Augmentation):
    """Augmentation that resizes every image to one fixed (h, w) size."""

    def __init__(self, shape, interp=Image.BICUBIC):
        """
        Args:
            shape: target size as an (h, w) sequence, or a single int for a square
            interp: PIL interpolation method
        """
        shape = tuple((shape, shape) if isinstance(shape, int) else shape)
        # Stores the local names ``shape`` and ``interp`` as attributes; do not
        # introduce further locals before this call.
        self._init(locals())

    def get_transform(self, image):
        """Return the transform mapping ``image``'s size onto the target size."""
        source_h = image.shape[0]
        source_w = image.shape[1]
        target_h = self.shape[0]
        target_w = self.shape[1]
        return T.ResizeTransform(source_h, source_w, target_h, target_w, self.interp)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Custom transform that resizes and then pads to fixed size
|
| 25 |
+
class ResizeAndPad(T.Augmentation):
    """Resize an image preserving its aspect ratio, then pad it centrally to a fixed size."""

    def __init__(self, target_size, pad_value=0, interp=Image.BICUBIC):
        """
        Args:
            target_size: (height, width) of the output image
            pad_value: value used to fill the padded border
            interp: PIL interpolation method used for the resize
        """
        super().__init__()
        self.target_size = target_size  # (height, width)
        self.interp = interp
        self.pad_value = pad_value

    def get_transform(self, img):
        """Return the resize-then-pad transform list for ``img``."""
        h, w = img.shape[:2]
        target_h, target_w = self.target_size[0], self.target_size[1]
        # Largest scale at which both sides still fit inside the target.
        scale = min(target_h / h, target_w / w)
        new_h, new_w = int(h * scale), int(w * scale)

        # First resize preserving aspect ratio
        resize_t = T.ResizeTransform(h, w, new_h, new_w, self.interp)

        # Then pad to target size, splitting the border as evenly as possible
        pad_h, pad_w = target_h - new_h, target_w - new_w
        top = pad_h // 2
        left = pad_w // 2
        # PadTransform's optional size arguments are ordered (orig_w, orig_h);
        # pass them by keyword so the pre-padding width/height are not swapped
        # (they are what the inverse crop uses).
        pad_t = T.PadTransform(
            left,
            top,
            pad_w - left,
            pad_h - top,
            orig_w=new_w,
            orig_h=new_h,
            pad_value=self.pad_value,
        )

        return T.TransformList([resize_t, pad_t])
|
detectron2/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
from .utils.env import setup_environment
|
| 4 |
+
|
| 5 |
+
setup_environment()
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# This line will be programmatically read and written by setup.py.
|
| 9 |
+
# Leave them at the bottom of this file and don't touch them.
|
| 10 |
+
__version__ = "0.6"
|
detectron2/checkpoint/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
# File:
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
|
| 7 |
+
|
| 8 |
+
from . import catalog as _UNUSED # register the handler
|
| 9 |
+
from .detection_checkpoint import DetectionCheckpointer
|
| 10 |
+
|
| 11 |
+
__all__ = ["Checkpointer", "PeriodicCheckpointer", "DetectionCheckpointer"]
|
detectron2/checkpoint/c2_model_loading.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import re
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from tabulate import tabulate
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def convert_basic_c2_names(original_keys):
    """
    Rename Caffe2 weight names with a fixed sequence of basic substitutions.
    Only typical backbone names are handled.

    Args:
        original_keys (list[str]):
    Returns:
        list[str]: One converted name per entry of original_keys, in the same order.
    """
    # Whole-name special cases, applied before anything else.
    hard_coded = {"pred_b": "linear_b", "pred_w": "linear_w"}

    # Regex rewrites, applied in this order after "_" has become ".".
    regex_rules = (
        ("\\.b$", ".bias"),
        ("\\.w$", ".weight"),
        # Uniform both bn and gn names to "norm"
        ("bn\\.s$", "norm.weight"),
        ("bn\\.bias$", "norm.bias"),
        ("bn\\.rm", "norm.running_mean"),
        ("bn\\.running.mean$", "norm.running_mean"),
        ("bn\\.riv$", "norm.running_var"),
        ("bn\\.running.var$", "norm.running_var"),
        ("bn\\.gamma$", "norm.weight"),
        ("bn\\.beta$", "norm.bias"),
        ("gn\\.s$", "norm.weight"),
        ("gn\\.bias$", "norm.bias"),
        # stem
        ("^res\\.conv1\\.norm\\.", "conv1.norm."),
        # to avoid mis-matching with "conv1" in other components (e.g. detection head)
        ("^conv1\\.", "stem.conv1."),
    )

    # Residual-block branches. The C2 stage names (res2-5) are kept as they are;
    # they are not mapped to torchvision's layer1-4.
    block_rules = (
        (".branch1.", ".shortcut."),
        (".branch2a.", ".conv1."),
        (".branch2b.", ".conv2."),
        (".branch2c.", ".conv3."),
    )

    # DensePose substitutions (plain string replacements)
    densepose_rules = (
        ("AnnIndex.lowres", "ann_index_lowres"),
        ("Index.UV.lowres", "index_uv_lowres"),
        ("U.lowres", "u_lowres"),
        ("V.lowres", "v_lowres"),
    )

    def rename(key):
        name = hard_coded.get(key, key)
        name = name.replace("_", ".")
        for pattern, replacement in regex_rules:
            name = re.sub(pattern, replacement, name)
        for old, new in block_rules:
            name = name.replace(old, new)
        name = re.sub("^body.conv.fcn", "body_conv_fcn", name)
        for old, new in densepose_rules:
            name = name.replace(old, new)
        return name

    return [rename(key) for key in copy.deepcopy(original_keys)]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def convert_c2_detectron_names(weights):
|
| 68 |
+
"""
|
| 69 |
+
Map Caffe2 Detectron weight names to Detectron2 names.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
weights (dict): name -> tensor
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
dict: detectron2 names -> tensor
|
| 76 |
+
dict: detectron2 names -> C2 names
|
| 77 |
+
"""
|
| 78 |
+
logger = logging.getLogger(__name__)
|
| 79 |
+
logger.info("Renaming Caffe2 weights ......")
|
| 80 |
+
original_keys = sorted(weights.keys())
|
| 81 |
+
layer_keys = copy.deepcopy(original_keys)
|
| 82 |
+
|
| 83 |
+
layer_keys = convert_basic_c2_names(layer_keys)
|
| 84 |
+
|
| 85 |
+
# --------------------------------------------------------------------------
|
| 86 |
+
# RPN hidden representation conv
|
| 87 |
+
# --------------------------------------------------------------------------
|
| 88 |
+
# FPN case
|
| 89 |
+
# In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
|
| 90 |
+
# shared for all other levels, hence the appearance of "fpn2"
|
| 91 |
+
layer_keys = [k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys]
|
| 92 |
+
# Non-FPN case
|
| 93 |
+
layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
|
| 94 |
+
|
| 95 |
+
# --------------------------------------------------------------------------
|
| 96 |
+
# RPN box transformation conv
|
| 97 |
+
# --------------------------------------------------------------------------
|
| 98 |
+
# FPN case (see note above about "fpn2")
|
| 99 |
+
layer_keys = [k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys]
|
| 100 |
+
layer_keys = [
|
| 101 |
+
k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits") for k in layer_keys
|
| 102 |
+
]
|
| 103 |
+
# Non-FPN case
|
| 104 |
+
layer_keys = [k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys]
|
| 105 |
+
layer_keys = [k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits") for k in layer_keys]
|
| 106 |
+
|
| 107 |
+
# --------------------------------------------------------------------------
|
| 108 |
+
# Fast R-CNN box head
|
| 109 |
+
# --------------------------------------------------------------------------
|
| 110 |
+
layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
|
| 111 |
+
layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
|
| 112 |
+
layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
|
| 113 |
+
layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
|
| 114 |
+
# 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
|
| 115 |
+
layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
|
| 116 |
+
|
| 117 |
+
# --------------------------------------------------------------------------
|
| 118 |
+
# FPN lateral and output convolutions
|
| 119 |
+
# --------------------------------------------------------------------------
|
| 120 |
+
def fpn_map(name):
    """
    Rename FPN convolution keys; every other key is returned untouched.

    Two patterns are recognised:
        1) Keys starting with "fpn.inner." (lateral pathway convolutions)
            Example: "fpn.inner.res2.2.sum.lateral.weight" -> "fpn_lateral2.weight"
        2) Keys starting with "fpn.res" (FPN output convolutions)
            Example: "fpn.res2.2.sum.weight" -> "fpn_output2.weight"

    If one of the dot-separated components is exactly "norm", ".norm" is
    inserted between the stage number and the final component.
    """
    # Pick the new base name and the position of the "resN" component.
    if name.startswith("fpn.inner."):
        # parts example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
        base, stage_pos = "fpn_lateral", 2
    elif name.startswith("fpn.res"):
        # parts example: ['fpn', 'res2', '2', 'sum', 'weight']
        base, stage_pos = "fpn_output", 1
    else:
        return name

    parts = name.split(".")
    norm_suffix = ".norm" if "norm" in parts else ""
    # Strip the leading "res" to obtain the stage number.
    stage = int(parts[stage_pos][len("res"):])
    return f"{base}{stage}{norm_suffix}.{parts[-1]}"
|
| 141 |
+
|
| 142 |
+
layer_keys = [fpn_map(k) for k in layer_keys]
|
| 143 |
+
|
| 144 |
+
# --------------------------------------------------------------------------
|
| 145 |
+
# Mask R-CNN mask head
|
| 146 |
+
# --------------------------------------------------------------------------
|
| 147 |
+
# roi_heads.StandardROIHeads case
|
| 148 |
+
layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
|
| 149 |
+
layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
|
| 150 |
+
layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
|
| 151 |
+
# roi_heads.Res5ROIHeads case
|
| 152 |
+
layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
|
| 153 |
+
|
| 154 |
+
# --------------------------------------------------------------------------
|
| 155 |
+
# Keypoint R-CNN head
|
| 156 |
+
# --------------------------------------------------------------------------
|
| 157 |
+
# interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
|
| 158 |
+
layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
|
| 159 |
+
layer_keys = [k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys]
|
| 160 |
+
layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
|
| 161 |
+
|
| 162 |
+
# --------------------------------------------------------------------------
|
| 163 |
+
# Done with replacements
|
| 164 |
+
# --------------------------------------------------------------------------
|
| 165 |
+
assert len(set(layer_keys)) == len(layer_keys)
|
| 166 |
+
assert len(original_keys) == len(layer_keys)
|
| 167 |
+
|
| 168 |
+
new_weights = {}
|
| 169 |
+
new_keys_to_original_keys = {}
|
| 170 |
+
for orig, renamed in zip(original_keys, layer_keys):
|
| 171 |
+
new_keys_to_original_keys[renamed] = orig
|
| 172 |
+
if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
|
| 173 |
+
# remove the meaningless prediction weight for background class
|
| 174 |
+
new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
|
| 175 |
+
new_weights[renamed] = weights[orig][new_start_idx:]
|
| 176 |
+
logger.info(
|
| 177 |
+
"Remove prediction weight for background class in {}. The shape changes from "
|
| 178 |
+
"{} to {}.".format(renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape))
|
| 179 |
+
)
|
| 180 |
+
elif renamed.startswith("cls_score."):
|
| 181 |
+
# move weights of bg class from original index 0 to last index
|
| 182 |
+
logger.info(
|
| 183 |
+
"Move classification weights for background class in {} from index 0 to "
|
| 184 |
+
"index {}.".format(renamed, weights[orig].shape[0] - 1)
|
| 185 |
+
)
|
| 186 |
+
new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
|
| 187 |
+
else:
|
| 188 |
+
new_weights[renamed] = weights[orig]
|
| 189 |
+
|
| 190 |
+
return new_weights, new_keys_to_original_keys
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# Note the current matching is not symmetric.
|
| 194 |
+
# it assumes model_state_dict will have longer names.
|
| 195 |
+
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
    """
    Match names between the two state dicts, and return a new checkpoint state dict
    with names converted to match model_state_dict with heuristics. The returned dict
    can be later loaded with fvcore checkpointer.
    If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
    model and will be renamed at first.

    Strategy: suppose that the models that we will create will have prefixes appended
    to each of its keys, for example due to an extra level of nesting that the original
    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
    res2.conv1.weight. We thus want to match both parameters together.
    For that, we look for each model weight, look among all loaded keys if there is one
    that is a suffix of the current weight name, and use it if that's the case.
    If multiple matches exist, take the one with longest size
    of the corresponding name. For example, for the same model as before, the pretrained
    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
    we want to match backbone[0].body.conv1.weight to conv1.weight, and
    backbone[0].body.res2.conv1.weight to res2.conv1.weight.

    Args:
        model_state_dict: mapping from model parameter name to a value exposing `.shape`.
        ckpt_state_dict: mapping from checkpoint parameter name to a value exposing `.shape`.
        c2_conversion: if True, run `convert_c2_detectron_names` on the checkpoint first.

    Returns:
        dict: matched checkpoint values stored under their model key, plus every
        checkpoint entry that was not matched stored under its own checkpoint key.
        If nothing matches at all (including an empty checkpoint), the (possibly
        renamed) checkpoint dict is returned as is.

    Raises:
        ValueError: if one checkpoint key is the best match of several model keys.
    """
    logger = logging.getLogger(__name__)
    model_keys = sorted(model_state_dict.keys())
    if c2_conversion:
        ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
        # original_keys: the name in the original dict (before renaming)
    else:
        original_keys = {x: x for x in ckpt_state_dict.keys()}
    ckpt_keys = sorted(ckpt_state_dict.keys())

    if not ckpt_keys:
        # With no checkpoint keys the match matrix would have a zero-sized
        # dimension to reduce over; handle it like the "nothing matched" case.
        logger.warning("No weights in checkpoint matched with model.")
        return ckpt_state_dict

    def match(a, b):
        # Matched ckpt_key should be a complete (starts with '.') suffix.
        # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
        # but matches whatever_conv1 or mesh_head.whatever_conv1.
        return a == b or a.endswith("." + b)

    # get a matrix of string matches, where each (i, j) entry correspond to the size of the
    # ckpt_key string, if it matches
    match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
    match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
    # use the matched one with longest size in case of multiple matches
    max_match_size, idxs = match_matrix.max(1)
    # remove indices that correspond to no-match
    idxs[max_match_size == 0] = -1

    # matched_pairs (matched checkpoint key --> matched model key)
    matched_keys = {}
    result_state_dict = {}
    for idx_model, idx_ckpt in enumerate(idxs.tolist()):
        if idx_ckpt == -1:
            continue
        key_model = model_keys[idx_model]
        key_ckpt = ckpt_keys[idx_ckpt]
        value_ckpt = ckpt_state_dict[key_ckpt]
        shape_in_model = model_state_dict[key_model].shape

        if shape_in_model != value_ckpt.shape:
            logger.warning(
                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
                    key_ckpt, value_ckpt.shape, key_model, shape_in_model
                )
            )
            logger.warning("{} will not be loaded. Please double check and see if this is desired.".format(key_ckpt))
            continue

        assert key_model not in result_state_dict
        result_state_dict[key_model] = value_ckpt
        if key_ckpt in matched_keys:  # already added to matched_keys
            logger.error(
                "Ambiguity found for {} in checkpoint! "
                "It matches at least two keys in the model ({} and {}).".format(
                    key_ckpt, key_model, matched_keys[key_ckpt]
                )
            )
            raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")

        matched_keys[key_ckpt] = key_model

    # logging:
    matched_model_keys = sorted(matched_keys.values())
    if len(matched_model_keys) == 0:
        logger.warning("No weights in checkpoint matched with model.")
        return ckpt_state_dict
    common_prefix = _longest_common_prefix(matched_model_keys)
    rev_matched_keys = {v: k for k, v in matched_keys.items()}
    # From here on original_keys maps matched model key -> name in the original checkpoint.
    original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}

    model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
    table = []
    memo = set()
    for key_model in matched_model_keys:
        if key_model in memo:
            continue
        if key_model in model_key_groups:
            # Report a whole group of parameters as one table row.
            group = model_key_groups[key_model]
            memo |= set(group)
            shapes = [tuple(model_state_dict[k].shape) for k in group]
            table.append(
                (
                    _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
                    _group_str([original_keys[k] for k in group]),
                    " ".join([str(x).replace(" ", "") for x in shapes]),
                )
            )
        else:
            key_checkpoint = original_keys[key_model]
            shape = str(tuple(model_state_dict[key_model].shape))
            table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
    table_str = tabulate(table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"])
    logger.info(
        "Following weights matched with "
        + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
        + ":\n"
        + table_str
    )

    # Carry over every checkpoint entry that was not matched, under its own key.
    # The lookup set is built once instead of once per checkpoint key.
    matched_ckpt_keys = set(matched_keys)
    for k in ckpt_keys:
        if k not in matched_ckpt_keys:
            result_state_dict[k] = ckpt_state_dict[k]
    return result_state_dict
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
    """
    Group together the params that live in the same submodule.

    Args:
        keys: names of all parameters
        original_names: mapping from parameter name to their name in the checkpoint

    Returns:
        dict[name -> all other names in the same group]
    """
    # "a.b.weight" -> "a.b."; names without any dot have no submodule prefix.
    prefixes = [k[: k.rfind(".") + 1] for k in keys if "." in k]
    # Shortest prefixes first, so the widest groups get to claim names first.
    prefixes.sort(key=len)

    groups = {}
    for prefix in prefixes:
        members = [k for k in keys if k.startswith(prefix)]
        if len(members) < 2:
            continue
        if not _longest_common_prefix_str([original_names[k] for k in members]):
            # don't group weights if original names don't share prefix
            continue
        for k in members:
            # A name keeps the first group that claimed it.
            groups.setdefault(k, members)
    return groups
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def _longest_common_prefix(names: List[str]) -> str:
|
| 358 |
+
"""
|
| 359 |
+
["abc.zfg", "abc.zef"] -> "abc."
|
| 360 |
+
"""
|
| 361 |
+
names = [n.split(".") for n in names]
|
| 362 |
+
m1, m2 = min(names), max(names)
|
| 363 |
+
ret = [a for a, b in zip(m1, m2) if a == b]
|
| 364 |
+
ret = ".".join(ret) + "." if len(ret) else ""
|
| 365 |
+
return ret
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def _longest_common_prefix_str(names: List[str]) -> str:
|
| 369 |
+
m1, m2 = min(names), max(names)
|
| 370 |
+
lcp = [a for a, b in zip(m1, m2) if a == b]
|
| 371 |
+
lcp = "".join(lcp)
|
| 372 |
+
return lcp
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def _group_str(names: List[str]) -> str:
    """
    Collapse names into one string: the prefix computed by
    `_longest_common_prefix_str`, followed by the remainders in braces.

    Turn "common1", "common2", "common3" into "common{1,2,3}"
    """
    shared = _longest_common_prefix_str(names)
    tails = ",".join(name[len(shared):] for name in names)
    grouped = shared + "{" + tails + "}"

    # add some simplification for BN specifically
    bn_spellings = (
        "bn_{beta,running_mean,running_var,gamma}",
        "bn_beta,bn_running_mean,bn_running_var,bn_gamma",
    )
    for spelling in bn_spellings:
        grouped = grouped.replace(spelling, "bn_*")
    return grouped
|
detectron2/checkpoint/catalog.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from detectron2.utils.file_io import PathHandler, PathManager
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ModelCatalog:
    """
    Store mappings from names to third-party models.

    `get` turns a catalog name into a download URL: names starting with
    "ImageNetPretrained/" are looked up in `C2_IMAGENET_MODELS`, names starting
    with "Caffe2Detectron/COCO" in `C2_DETECTRON_MODELS`.
    """

    S3_C2_DETECTRON_PREFIX = "https://dl.fbaipublicfiles.com/detectron"

    # MSRA models have STRIDE_IN_1X1=True. False otherwise.
    # NOTE: all BN models here have fused BN into an affine layer.
    # As a result, you should only load them to a model with "FrozenBN".
    # Loading them to a model with regular BN or SyncBN is wrong.
    # Even when loaded to FrozenBN, it is still different from affine by an epsilon,
    # which should be negligible for training.
    # NOTE: all models here uses PIXEL_STD=[1,1,1]
    # NOTE: Most of the BN models here are no longer used. We use the
    # re-converted pre-trained models under detectron2 model zoo instead.
    C2_IMAGENET_MODELS = {
        "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
        "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
        "FAIR/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
        "FAIR/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
        "FAIR/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
        "FAIR/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl",
        "FAIR/X-152-32x8d-IN5k": "ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl",
    }

    C2_DETECTRON_PATH_FORMAT = "{prefix}/{url}/output/train/{dataset}/{type}/model_final.pkl"  # noqa B950

    C2_DATASET_COCO = "coco_2014_train%3Acoco_2014_valminusminival"
    C2_DATASET_COCO_KEYPOINTS = "keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival"

    # format: {model_name} -> part of the url
    C2_DETECTRON_MODELS = {
        "35857197/e2e_faster_rcnn_R-50-C4_1x": "35857197/12_2017_baselines/e2e_faster_rcnn_R-50-C4_1x.yaml.01_33_49.iAX0mXvW",  # noqa B950
        "35857345/e2e_faster_rcnn_R-50-FPN_1x": "35857345/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_1x.yaml.01_36_30.cUF7QR7I",  # noqa B950
        "35857890/e2e_faster_rcnn_R-101-FPN_1x": "35857890/12_2017_baselines/e2e_faster_rcnn_R-101-FPN_1x.yaml.01_38_50.sNxI7sX7",  # noqa B950
        "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "36761737/12_2017_baselines/e2e_faster_rcnn_X-101-32x8d-FPN_1x.yaml.06_31_39.5MIHi1fZ",  # noqa B950
        "35858791/e2e_mask_rcnn_R-50-C4_1x": "35858791/12_2017_baselines/e2e_mask_rcnn_R-50-C4_1x.yaml.01_45_57.ZgkA7hPB",  # noqa B950
        "35858933/e2e_mask_rcnn_R-50-FPN_1x": "35858933/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml.01_48_14.DzEQe4wC",  # noqa B950
        "35861795/e2e_mask_rcnn_R-101-FPN_1x": "35861795/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_1x.yaml.02_31_37.KqyEK4tT",  # noqa B950
        "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "36761843/12_2017_baselines/e2e_mask_rcnn_X-101-32x8d-FPN_1x.yaml.06_35_59.RZotkLKI",  # noqa B950
        "48616381/e2e_mask_rcnn_R-50-FPN_2x_gn": "GN/48616381/04_2018_gn_baselines/e2e_mask_rcnn_R-50-FPN_2x_gn_0416.13_23_38.bTlTI97Q",  # noqa B950
        "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "37697547/12_2017_baselines/e2e_keypoint_rcnn_R-50-FPN_1x.yaml.08_42_54.kdzV35ao",  # noqa B950
        "35998355/rpn_R-50-C4_1x": "35998355/12_2017_baselines/rpn_R-50-C4_1x.yaml.08_00_43.njH5oD9L",  # noqa B950
        "35998814/rpn_R-50-FPN_1x": "35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179",  # noqa B950
        "36225147/fast_R-50-FPN_1x": "36225147/12_2017_baselines/fast_rcnn_R-50-FPN_1x.yaml.08_39_09.L3obSdQ2",  # noqa B950
    }

    @staticmethod
    def get(name):
        """
        Return the URL for catalog entry `name`.

        Raises RuntimeError when `name` starts with neither known prefix; a name
        with a known prefix but no table entry raises KeyError from the lookup.
        """
        resolvers = (
            ("Caffe2Detectron/COCO", ModelCatalog._get_c2_detectron_baseline),
            ("ImageNetPretrained/", ModelCatalog._get_c2_imagenet_pretrained),
        )
        for known_prefix, resolve in resolvers:
            if name.startswith(known_prefix):
                return resolve(name)
        raise RuntimeError(f"model not present in the catalog: {name}")

    @staticmethod
    def _get_c2_imagenet_pretrained(name):
        """Build the URL of an entry of `C2_IMAGENET_MODELS`."""
        model_key = name[len("ImageNetPretrained/"):]
        relative_path = ModelCatalog.C2_IMAGENET_MODELS[model_key]
        return f"{ModelCatalog.S3_C2_DETECTRON_PREFIX}/{relative_path}"

    @staticmethod
    def _get_c2_detectron_baseline(name):
        """Build the URL of an entry of `C2_DETECTRON_MODELS`."""
        model_key = name[len("Caffe2Detectron/COCO/"):]
        url_part = ModelCatalog.C2_DETECTRON_MODELS[model_key]

        dataset = (
            ModelCatalog.C2_DATASET_COCO_KEYPOINTS
            if "keypoint_rcnn" in model_key
            else ModelCatalog.C2_DATASET_COCO
        )
        # this one model is somehow different from others ..
        model_type = "rpn" if "35998355/rpn_R-50-C4_1x" in model_key else "generalized_rcnn"

        # Detectron C2 models are stored in the structure defined in `C2_DETECTRON_PATH_FORMAT`.
        return ModelCatalog.C2_DETECTRON_PATH_FORMAT.format(
            prefix=ModelCatalog.S3_C2_DETECTRON_PREFIX,
            url=url_part,
            type=model_type,
            dataset=dataset,
        )
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ModelCatalogHandler(PathHandler):
    """
    Resolve URL like catalog://.

    The part of the path after the prefix is looked up with `ModelCatalog.get`,
    and the resulting location is handed back to `PathManager`.
    """

    PREFIX = "catalog://"

    def _get_supported_prefixes(self):
        return [self.PREFIX]

    def _get_local_path(self, path, **kwargs):
        log = logging.getLogger(__name__)
        # Drop the "catalog://" prefix to obtain the catalog entry name.
        entry_name = path[len(self.PREFIX):]
        resolved = ModelCatalog.get(entry_name)
        log.info(f"Catalog entry {path} points to {resolved}")
        return PathManager.get_local_path(resolved, **kwargs)

    def _open(self, path, mode="r", **kwargs):
        # kwargs go to the open call only; the local path is resolved without them.
        local_path = self._get_local_path(path)
        return PathManager.open(local_path, mode, **kwargs)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Module-level side effect: importing this module registers a
# ModelCatalogHandler instance (which advertises the "catalog://" prefix)
# with PathManager.
PathManager.register_handler(ModelCatalogHandler())
|