anas commited on
Commit
fadb92b
·
1 Parent(s): 54d83cd

Initial deployment of Raster2Seq floor plan vectorization API

Browse files

- Full inference pipeline with Gradio UI and API endpoint
- Docker-based deployment with CUDA 11.8 for GPU inference
- Semantic room detection with polygon output (13 room types)
- Checkpoint auto-downloaded during Docker build

Made-with: Cursor

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +15 -0
  2. Dockerfile +27 -0
  3. LICENSE +21 -0
  4. README.md +24 -5
  5. app.py +303 -0
  6. data_preprocess/README.md +40 -0
  7. data_preprocess/common_utils.py +45 -0
  8. data_preprocess/cubicasa5k/augmentations.py +703 -0
  9. data_preprocess/cubicasa5k/combine_json.py +118 -0
  10. data_preprocess/cubicasa5k/create_coco_cc5k.py +672 -0
  11. data_preprocess/cubicasa5k/floorplan_extraction.py +403 -0
  12. data_preprocess/cubicasa5k/house.py +1131 -0
  13. data_preprocess/cubicasa5k/loaders.py +158 -0
  14. data_preprocess/cubicasa5k/plotting.py +820 -0
  15. data_preprocess/cubicasa5k/run.sh +15 -0
  16. data_preprocess/cubicasa5k/svg_utils.py +746 -0
  17. data_preprocess/raster2graph/combine_json.py +122 -0
  18. data_preprocess/raster2graph/combine_mapping_ids.py +95 -0
  19. data_preprocess/raster2graph/convert_to_coco.py +472 -0
  20. data_preprocess/raster2graph/dataset.py +296 -0
  21. data_preprocess/raster2graph/image_process.py +67 -0
  22. data_preprocess/raster2graph/util/data_utils.py +966 -0
  23. data_preprocess/raster2graph/util/edges_utils.py +46 -0
  24. data_preprocess/raster2graph/util/geom_utils.py +124 -0
  25. data_preprocess/raster2graph/util/graph_utils.py +879 -0
  26. data_preprocess/raster2graph/util/image_id_dict.py +0 -0
  27. data_preprocess/raster2graph/util/math_utils.py +7 -0
  28. data_preprocess/raster2graph/util/mean_std.py +2 -0
  29. data_preprocess/raster2graph/util/metric_utils.py +338 -0
  30. data_preprocess/raster2graph/util/semantics_dict.py +45 -0
  31. data_preprocess/stru3d/PointCloudReaderPanorama.py +253 -0
  32. data_preprocess/stru3d/generate_coco_stru3d.py +199 -0
  33. data_preprocess/stru3d/generate_point_cloud_stru3d.py +32 -0
  34. data_preprocess/stru3d/stru3d_utils.py +244 -0
  35. data_preprocess/tools/plot_data.sh +60 -0
  36. data_preprocess/tools/run_cc5k.sh +15 -0
  37. data_preprocess/tools/run_r2g.sh +12 -0
  38. data_preprocess/tools/run_s3d.sh +22 -0
  39. data_preprocess/tools/run_waffle.sh +3 -0
  40. data_preprocess/waffle/create_coco_waffle_benchmark.py +290 -0
  41. datasets/__init__.py +67 -0
  42. datasets/data_utils.py +60 -0
  43. datasets/discrete_tokenizer.py +60 -0
  44. datasets/poly_data.py +590 -0
  45. datasets/room_dropout.py +237 -0
  46. datasets/transforms.py +46 -0
  47. detectron2/__init__.py +10 -0
  48. detectron2/checkpoint/__init__.py +11 -0
  49. detectron2/checkpoint/c2_model_loading.py +387 -0
  50. detectron2/checkpoint/catalog.py +113 -0
.gitignore ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.egg-info/
5
+ build/
6
+ dist/
7
+ data/
8
+ data2/
9
+ output*/
10
+ wandb*/
11
+ checkpoints/
12
+ slurm_scripts*
13
+ watch_folder
14
+ cross_eval_out
15
+ *.log
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
2
+
3
+ RUN apt-get update && apt-get install -y \
4
+ python3.10 python3-pip git wget libgl1-mesa-glx libglib2.0-0 \
5
+ && rm -rf /var/lib/apt/lists/*
6
+
7
+ RUN useradd -m -u 1000 user
8
+ WORKDIR /app
9
+ COPY . /app
10
+
11
+ RUN pip3 install torch==2.3.1 torchvision==0.18.1 \
12
+ --index-url https://download.pytorch.org/whl/cu118
13
+ RUN pip3 install -r requirements.txt
14
+ RUN pip3 install gradio gdown
15
+
16
+ RUN cd models/ops && sh make.sh && cd ../..
17
+ RUN cd diff_ras && python3 setup.py build develop && cd ..
18
+
19
+ RUN mkdir -p checkpoints && \
20
+ gdown --fuzzy "https://drive.google.com/file/d/1M32HlYwXw-4Q_uajSCvpbF31UFPzQVHP/view?usp=sharing" \
21
+ -O checkpoints/r2g_res256_ep0849.pth
22
+
23
+ RUN chown -R user:user /app
24
+ USER user
25
+
26
+ EXPOSE 7860
27
+ CMD ["python3", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Raster2Seq
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,10 +1,29 @@
1
  ---
2
- title: Raster2seq
3
- emoji: 🐢
4
- colorFrom: yellow
5
- colorTo: gray
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Raster2Seq
3
+ emoji: 🏠
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ app_port: 7860
9
  ---
10
 
11
+ # Raster2Seq - Floor Plan Vectorization
12
+
13
+ Upload a floor plan image to detect room polygons and their semantic labels.
14
+
15
+ This Space runs the Raster2Seq model for converting raster floor plan images into vectorized polygon sequences with room type classification.
16
+
17
+ ## API Usage
18
+
19
+ This Space exposes a Gradio API. You can call it programmatically:
20
+
21
+ ```python
22
+ from gradio_client import Client
23
+
24
+ client = Client("AGLO-AI/raster2seq")
25
+ result = client.predict(
26
+ image="path/to/floorplan.png",
27
+ api_name="/predict"
28
+ )
29
+ ```
app.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ import json
4
+ import math
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from shapely.geometry import Polygon
12
+
13
+ from datasets.discrete_tokenizer import DiscreteTokenizer
14
+ from datasets.transforms import ResizeAndPad
15
+ from detectron2.data import transforms as T
16
+ from models import build_model
17
+ from util.plot_utils import plot_semantic_rich_floorplan_opencv
18
+
19
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+ MODEL_ARGS = argparse.Namespace(
22
+ poly2seq=True,
23
+ seq_len=512,
24
+ num_bins=32,
25
+ image_size=256,
26
+ input_channels=3,
27
+ backbone="resnet50",
28
+ dilation=False,
29
+ position_embedding="sine",
30
+ position_embedding_scale=2 * np.pi,
31
+ num_feature_levels=4,
32
+ enc_layers=6,
33
+ dec_layers=6,
34
+ dim_feedforward=1024,
35
+ hidden_dim=256,
36
+ dropout=0.1,
37
+ nheads=8,
38
+ num_queries=800,
39
+ num_polys=20,
40
+ dec_n_points=4,
41
+ enc_n_points=4,
42
+ query_pos_type="sine",
43
+ with_poly_refine=False,
44
+ masked_attn=False,
45
+ semantic_classes=13,
46
+ aux_loss=False,
47
+ dec_attn_concat_src=True,
48
+ pre_decoder_pos_embed=False,
49
+ learnable_dec_pe=False,
50
+ dec_qkv_proj=False,
51
+ per_token_sem_loss=True,
52
+ add_cls_token=False,
53
+ use_anchor=True,
54
+ inject_cls_embed=False,
55
+ device="cuda" if torch.cuda.is_available() else "cpu",
56
+ )
57
+
58
+ R2G_LABEL = {
59
+ 0: "Living Room",
60
+ 1: "Kitchen",
61
+ 2: "Bedroom",
62
+ 3: "Bathroom",
63
+ 4: "Balcony",
64
+ 5: "Corridor",
65
+ 6: "Dining Room",
66
+ 7: "Study",
67
+ 8: "Studio",
68
+ 9: "Store Room",
69
+ 10: "Garden",
70
+ 11: "Laundry Room",
71
+ 12: "Office",
72
+ 13: "Basement",
73
+ 14: "Garage",
74
+ 15: "Undefined",
75
+ 16: "Door",
76
+ 17: "Window",
77
+ }
78
+
79
+
80
+ def _process_predictions(
81
+ pred_corners, i, semantic_rich, image_size, pred_room_label,
82
+ pred_room_logits, per_token_sem_loss, add_cls_token=False,
83
+ ):
84
+ """Extract polygons from poly2seq model output."""
85
+ np_softmax = lambda x: np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)
86
+ pred_corners_per_scene = pred_corners[i]
87
+ room_polys = []
88
+
89
+ if semantic_rich:
90
+ room_types = []
91
+ window_doors = []
92
+ window_doors_types = []
93
+ pred_room_label_per_scene = pred_room_label[i].cpu().numpy()
94
+ pred_room_logit_per_scene = pred_room_logits[i].cpu().numpy()
95
+
96
+ all_room_polys = []
97
+ tmp = []
98
+ all_length_list = [0]
99
+
100
+ for j in range(len(pred_corners_per_scene)):
101
+ if isinstance(pred_corners_per_scene[j], int):
102
+ if pred_corners_per_scene[j] == 2 and tmp:
103
+ all_room_polys.append(tmp)
104
+ all_length_list.append(len(tmp) + 1 + add_cls_token)
105
+ tmp = []
106
+ continue
107
+ tmp.append(pred_corners_per_scene[j])
108
+
109
+ if len(tmp):
110
+ all_room_polys.append(tmp)
111
+ all_length_list.append(len(tmp) + 1 + add_cls_token)
112
+
113
+ start_poly_indices = np.cumsum(all_length_list)
114
+ final_pred_classes = []
115
+
116
+ for j, poly in enumerate(all_room_polys):
117
+ if len(poly) < 2:
118
+ continue
119
+ corners = np.array(poly, dtype=np.float32) * (image_size - 1)
120
+ corners = np.around(corners).astype(np.int32)
121
+
122
+ if not semantic_rich:
123
+ if len(corners) >= 4 and Polygon(corners).area >= 100:
124
+ room_polys.append(corners)
125
+ else:
126
+ if per_token_sem_loss:
127
+ pred_classes, counts = np.unique(
128
+ pred_room_label_per_scene[start_poly_indices[j]:start_poly_indices[j + 1]][:-1],
129
+ return_counts=True,
130
+ )
131
+ pred_class = pred_classes[np.argmax(counts)]
132
+ else:
133
+ pred_class = pred_room_label_per_scene[start_poly_indices[j + 1] - 1]
134
+ final_pred_classes.append(pred_class)
135
+
136
+ if len(corners) >= 3 and Polygon(corners).area >= 100:
137
+ room_polys.append(corners)
138
+ room_types.append(pred_class)
139
+ elif len(corners) == 2:
140
+ window_doors.append(corners)
141
+ window_doors_types.append(pred_class)
142
+
143
+ if not semantic_rich:
144
+ pred_room_label_per_scene = len(all_room_polys) * [-1]
145
+
146
+ return {
147
+ "room_polys": room_polys,
148
+ "room_types": room_types if semantic_rich else None,
149
+ "window_doors": window_doors if semantic_rich else None,
150
+ "window_doors_types": window_doors_types if semantic_rich else None,
151
+ }
152
+
153
+
154
+ @torch.no_grad()
155
+ def generate(model, samples, semantic_rich=False, use_cache=True, per_token_sem_loss=False):
156
+ """Generate room polygons from model predictions (poly2seq mode only)."""
157
+ model.eval()
158
+ image_size = samples[0].size(2)
159
+
160
+ outputs = model.forward_inference(samples, use_cache)
161
+ pred_corners = outputs["gen_out"]
162
+
163
+ bs = outputs["pred_logits"].shape[0]
164
+
165
+ pred_room_label = None
166
+ pred_room_logits = None
167
+ if "pred_room_logits" in outputs:
168
+ pred_room_logits = outputs["pred_room_logits"]
169
+ prob = torch.nn.functional.softmax(pred_room_logits, -1)
170
+ _, pred_room_label = prob[..., :-1].max(-1)
171
+
172
+ result_rooms = []
173
+ result_classes = []
174
+
175
+ for i in range(bs):
176
+ scene_outputs = _process_predictions(
177
+ pred_corners, i, semantic_rich, image_size,
178
+ pred_room_label, pred_room_logits, per_token_sem_loss,
179
+ )
180
+ room_polys = scene_outputs["room_polys"]
181
+ room_types = scene_outputs["room_types"]
182
+ window_doors = scene_outputs["window_doors"]
183
+ window_doors_types = scene_outputs["window_doors_types"]
184
+
185
+ if window_doors:
186
+ result_rooms.append(room_polys + window_doors)
187
+ result_classes.append(room_types + window_doors_types)
188
+ else:
189
+ result_rooms.append(room_polys)
190
+ result_classes.append(room_types)
191
+
192
+ return {"room": result_rooms, "labels": result_classes}
193
+
194
+
195
+ def load_model():
196
+ tokenizer = DiscreteTokenizer(
197
+ MODEL_ARGS.num_bins, MODEL_ARGS.seq_len, add_cls=MODEL_ARGS.add_cls_token
198
+ )
199
+ MODEL_ARGS.vocab_size = len(tokenizer)
200
+
201
+ model = build_model(MODEL_ARGS, train=False, tokenizer=tokenizer)
202
+ model.to(DEVICE)
203
+
204
+ ckpt_path = "checkpoints/r2g_res256_ep0849.pth"
205
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
206
+ ckpt_state_dict = copy.deepcopy(checkpoint["ema"])
207
+ for key in list(ckpt_state_dict.keys()):
208
+ if key.startswith("module."):
209
+ ckpt_state_dict[key[7:]] = ckpt_state_dict.pop(key)
210
+ model.load_state_dict(ckpt_state_dict, strict=False)
211
+
212
+ for param in model.parameters():
213
+ param.requires_grad = False
214
+ model.eval()
215
+ return model
216
+
217
+
218
+ print("Loading model...")
219
+ MODEL = load_model()
220
+ print("Model loaded.")
221
+
222
+ DATA_TRANSFORM = T.AugmentationList(
223
+ [ResizeAndPad((MODEL_ARGS.image_size, MODEL_ARGS.image_size), pad_value=255)]
224
+ )
225
+
226
+
227
+ def preprocess_image(pil_image: Image.Image) -> torch.Tensor:
228
+ image_np = np.array(pil_image.convert("RGB"))
229
+ aug_input = T.AugInput(image_np)
230
+ DATA_TRANSFORM(aug_input)
231
+ image_np = aug_input.image
232
+
233
+ if len(image_np.shape) == 2:
234
+ tensor = np.expand_dims(image_np, 0)
235
+ else:
236
+ tensor = image_np.transpose((2, 0, 1))
237
+
238
+ return (1 / 255) * torch.as_tensor(tensor, dtype=torch.float32)
239
+
240
+
241
+ def predict_floorplan(image: Image.Image):
242
+ if image is None:
243
+ return None, json.dumps({"error": "No image provided"})
244
+
245
+ input_tensor = preprocess_image(image).unsqueeze(0).to(DEVICE)
246
+
247
+ outputs = generate(
248
+ MODEL,
249
+ input_tensor,
250
+ semantic_rich=MODEL_ARGS.semantic_classes > 0,
251
+ use_cache=True,
252
+ per_token_sem_loss=MODEL_ARGS.per_token_sem_loss,
253
+ )
254
+
255
+ pred_rooms = outputs["room"][0]
256
+ pred_labels = outputs["labels"][0]
257
+ image_size = MODEL_ARGS.image_size
258
+
259
+ if pred_labels is None:
260
+ pred_labels = [-1] * len(pred_rooms)
261
+
262
+ result_polygons = []
263
+ for poly, label in zip(pred_rooms, pred_labels):
264
+ coords = poly.astype(float).tolist()
265
+ result_polygons.append({
266
+ "label": R2G_LABEL.get(int(label), "Unknown"),
267
+ "label_id": int(label),
268
+ "polygon": coords,
269
+ })
270
+
271
+ floorplan_map = plot_semantic_rich_floorplan_opencv(
272
+ zip(pred_rooms, pred_labels),
273
+ None,
274
+ door_window_index=[],
275
+ semantics_label_mapping=R2G_LABEL,
276
+ plot_text=True,
277
+ one_color=False,
278
+ is_sem=True,
279
+ img_w=image_size * 2,
280
+ img_h=image_size * 2,
281
+ scale=2,
282
+ )
283
+ if floorplan_map is not None and floorplan_map.size > 0:
284
+ floorplan_rgb = cv2.cvtColor(floorplan_map, cv2.COLOR_BGR2RGB)
285
+ vis_image = Image.fromarray(floorplan_rgb)
286
+ else:
287
+ vis_image = None
288
+
289
+ return vis_image, result_polygons
290
+
291
+
292
+ demo = gr.Interface(
293
+ fn=predict_floorplan,
294
+ inputs=gr.Image(type="pil", label="Floor Plan Image"),
295
+ outputs=[
296
+ gr.Image(type="pil", label="Detected Rooms"),
297
+ gr.JSON(label="Detected Polygons"),
298
+ ],
299
+ title="Raster2Seq - Floor Plan Vectorization",
300
+ description="Upload a floor plan image to detect room polygons and their semantic labels. Returns both a visualization and structured JSON with polygon coordinates.",
301
+ )
302
+
303
+ demo.launch(server_name="0.0.0.0", server_port=7860)
data_preprocess/README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Data preprocessing
2
+
3
+ ### Structured3D
4
+
5
+ Simply download preprocessed data by RoomFormer at [here](https://polybox.ethz.ch/index.php/s/wKYWFsQOXHnkwcG). For more details, please refer to [RoomFormer's instructions](https://github.com/ywyue/RoomFormer/tree/main/data_preprocess).
6
+
7
+ To render binary floorplan images from GT annotations (as used in our paper), run `bash data_preprocess/tools/run_s3d.sh`.
8
+
9
+ ### CubiCasa5K
10
+ Step 1: Download and extract [CubiCasa5K](https://zenodo.org/record/2613548) dataset.
11
+
12
+ Step 2: Run `bash data_preprocess/cubicasa5k/run.sh`.
13
+
14
+ ### Raster2Graph
15
+ The instruction mainly follows Raster2Graph's instruction.
16
+
17
+ Step 1: Due to dataset proprietary restrictions, please apply for access to LIFULL HOME'S Data [here](https://www.nii.ac.jp/dsc/idr/en/lifull/).
18
+
19
+ Step 2: After obtaining access, download only the "photo-rent-madori-full-00" folder, which contains approximately 300,000 images.
20
+
21
+ Step 3: Apply for access to the annotation [here](https://docs.google.com/forms/d/e/1FAIpQLSexqNMjyvPMtPMPN7bSh_1u4Q27LZAT-S9lR_gpipNIMKV5lw/viewform).
22
+
23
+ The package has 3 folders:
24
+ - annot_npy, annot_json: the annotations saved in npy and json, respectively.
25
+ - original_vector_boundary: boundary boxes of "LIFULL HOME'S Data" which is used to create centered 512x512 images.
26
+
27
+ These folders should be saved in the same directory as `photo-rent-madori-full-00`. For example: `data/R2G_hr_dataset/`.
28
+
29
+ Step 4: Run `bash data_preprocess/tools/run_r2g.sh`.
30
+
31
+ ### WAFFLE
32
+
33
+ It is noted that since WAFFLE only provides segmentation masks for a subset of 100 examples, so we only process this subset for the evaluation, not for training.
34
+
35
+ Step 1: Download data at [here](https://tauex-my.sharepoint.com/:f:/g/personal/hadarelor_tauex_tau_ac_il/EqMX9nRbJ9xFiK7dR_m07b8BldS2saoZ4-ockqncJb_Hrg?e=zGIuos)
36
+
37
+ Step 2: Run `bash data_preprocess/tools/run_waffle.sh`.
38
+
39
+ ## Data visualization
40
+ Please refer to this script [tools/plot_data.sh](tools/plot_data.sh).
data_preprocess/common_utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from plyfile import PlyData
6
+
7
+
8
+ def read_scene_pc(file_path):
9
+ with open(file_path, "rb") as f:
10
+ plydata = PlyData.read(f)
11
+ dtype = plydata["vertex"].data.dtype
12
+ print("dtype of file{}: {}".format(file_path, dtype))
13
+
14
+ points_data = np.array(plydata["vertex"].data.tolist())
15
+
16
+ return points_data
17
+
18
+
19
+ def is_clockwise(points):
20
+ # points is a list of 2d points.
21
+ assert len(points) > 0
22
+ s = 0.0
23
+ for p1, p2 in zip(points, points[1:] + [points[0]]):
24
+ s += (p2[0] - p1[0]) * (p2[1] + p1[1])
25
+ return s > 0.0
26
+
27
+
28
+ def resort_corners(corners):
29
+ # re-find the starting point and sort corners clockwisely
30
+ x_y_square_sum = corners[:, 0] ** 2 + corners[:, 1] ** 2
31
+ start_corner_idx = np.argmin(x_y_square_sum)
32
+
33
+ corners_sorted = np.concatenate([corners[start_corner_idx:], corners[:start_corner_idx]])
34
+
35
+ ## sort points clockwise
36
+ if not is_clockwise(corners_sorted[:, :2].tolist()):
37
+ corners_sorted[1:] = np.flip(corners_sorted[1:], 0)
38
+
39
+ return corners
40
+
41
+
42
+ def export_density(density_map, out_folder, scene_id):
43
+ density_path = os.path.join(out_folder, scene_id + ".png")
44
+ density_uint8 = (density_map * 255).astype(np.uint8)
45
+ cv2.imwrite(density_path, density_uint8)
data_preprocess/cubicasa5k/augmentations.py ADDED
@@ -0,0 +1,703 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from math import inf
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from floortrans.loaders import svg_utils
8
+
9
+
10
+ class Compose(object):
11
+ def __init__(self, augmentations):
12
+ self.augmentations = augmentations
13
+
14
+ def __call__(self, sample):
15
+ for a in self.augmentations:
16
+ sample = a(sample)
17
+
18
+ return sample
19
+
20
+
21
+ # 0. I
22
+ # 1. I top to right
23
+ # 2. I vertical flip
24
+ # 3. I top to left
25
+ # 4. L horizontal flip
26
+ # 5. L
27
+ # 6. L vertical flip
28
+ # 7. L horizontal and vertical flip
29
+ # 8. T
30
+ # 9. T top to right
31
+ # 10. T top to down
32
+ # 11. T top to left
33
+ # 12. X or +
34
+ # 13. Opening left corner
35
+ # 14. Opening right corner
36
+ # 15. Opening up corner
37
+ # 16. Opening down corer
38
+ # 17. Icon upper left
39
+ # 18. Icon upper right
40
+ # 19. Icon lower left
41
+ # 20. Icon lower right
42
+
43
+
44
+ class RandomRotations(object):
45
+ def __init__(self, format="furu"):
46
+ if format == "furu":
47
+ self.augment = self.furu
48
+ elif format == "cubi":
49
+ self.augment = self.cubi
50
+
51
+ def __call__(self, sample):
52
+ return self.augment(sample)
53
+
54
+ def cubi(self, sample):
55
+ fplan = sample["image"]
56
+ segmentation = sample["label"]
57
+ heatmap_points = sample["heatmaps"]
58
+ scale = sample["scale"]
59
+ num_of_rotations = int(torch.randint(0, 3, (1,)))
60
+ hmapp_convert_map = {
61
+ 0: 1,
62
+ 1: 2,
63
+ 2: 3,
64
+ 3: 0,
65
+ 4: 5,
66
+ 5: 6,
67
+ 6: 7,
68
+ 7: 4,
69
+ 8: 9,
70
+ 9: 10,
71
+ 10: 11,
72
+ 11: 8,
73
+ 12: 12,
74
+ 13: 15,
75
+ 14: 16,
76
+ 15: 14,
77
+ 16: 13,
78
+ 17: 18,
79
+ 18: 20,
80
+ 19: 17,
81
+ 20: 19,
82
+ }
83
+
84
+ for i in range(num_of_rotations):
85
+ fplan = fplan.transpose(2, 1).flip(2)
86
+ segmentation = segmentation.transpose(2, 1).flip(2)
87
+ points_rotated = dict()
88
+ for junction_type, points in heatmap_points.items():
89
+ new_junction_type = hmapp_convert_map[junction_type]
90
+ new_heatmap_points = []
91
+ for point in points:
92
+ x = fplan.shape[1] - 1 - point[1]
93
+ y = point[0]
94
+ # if y > 256 or x > 256:
95
+ # __import__('ipdb').set_trace()
96
+ new_heatmap_points.append([x, y])
97
+
98
+ points_rotated[new_junction_type] = new_heatmap_points
99
+
100
+ heatmap_points = points_rotated
101
+
102
+ sample = {"image": fplan, "label": segmentation, "scale": scale, "heatmaps": heatmap_points}
103
+
104
+ return sample
105
+
106
+ def furu(self, sample):
107
+ fplan = sample["image"]
108
+ segmentation = sample["label"]
109
+ heatmap_points = sample["heatmap_points"]
110
+ num_of_rotations = int(torch.randint(0, 3, (1,)))
111
+ for i in range(num_of_rotations):
112
+ fplan = fplan.transpose(2, 1).flip(2)
113
+ segmentation = segmentation.transpose(2, 1).flip(2)
114
+
115
+ hmapp_convert_map = {
116
+ 0: 1,
117
+ 1: 2,
118
+ 2: 3,
119
+ 3: 0,
120
+ 4: 5,
121
+ 5: 6,
122
+ 6: 7,
123
+ 7: 4,
124
+ 8: 9,
125
+ 9: 10,
126
+ 10: 11,
127
+ 11: 8,
128
+ 12: 12,
129
+ 13: 15,
130
+ 14: 16,
131
+ 15: 14,
132
+ 16: 13,
133
+ 17: 18,
134
+ 18: 20,
135
+ 19: 17,
136
+ 20: 19,
137
+ }
138
+
139
+ points_rotated = dict()
140
+ for junction_type, points in heatmap_points.items():
141
+ new_junction_type = hmapp_convert_map[junction_type]
142
+ new_heatmap_points = []
143
+ for point in points:
144
+ new_heatmap_points.append([fplan.shape[1] - 1 - point[1], point[0]])
145
+
146
+ points_rotated[new_junction_type] = new_heatmap_points
147
+
148
+ heatmap_points = points_rotated
149
+
150
+ sample = {"image": fplan, "label": segmentation, "heatmap_points": heatmap_points}
151
+
152
+ return sample
153
+
154
+
155
+ def clip_heatmaps(heatmaps, minx, maxx, miny, maxy):
156
+ def clip(p):
157
+ return p[0] < maxx and p[0] >= minx and p[1] < maxy and p[1] >= miny
158
+
159
+ res = {}
160
+ for key, value in heatmaps.items():
161
+ res[key] = list(filter(clip, value))
162
+ for i, e in enumerate(res[key]):
163
+ res[key][i] = (e[0] - minx, e[1] - miny)
164
+
165
+ return res
166
+
167
+
168
+ class DictToTensor(object):
169
+ def __init__(self, data_format="cubi"):
170
+ if data_format == "cubi":
171
+ self.augment = self.cubi
172
+ elif data_format == "furukawa":
173
+ self.augment = self.furukawa
174
+
175
+ def __call__(self, sample):
176
+ return self.augment(sample)
177
+
178
+ def cubi(self, sample):
179
+ image, label = sample["image"], sample["label"]
180
+ _, height, width = label.shape
181
+ heatmaps = sample["heatmaps"]
182
+ scale = sample["scale"]
183
+
184
+ heatmap_tensor = np.zeros((21, height, width))
185
+ for channel, coords in heatmaps.items():
186
+ for x, y in coords:
187
+ if x >= width:
188
+ x -= 1
189
+ if y >= height:
190
+ y -= 1
191
+ heatmap_tensor[int(channel), int(y), int(x)] = 1
192
+
193
+ # Gaussian filter
194
+ kernel = svg_utils.get_gaussian2D(int(30 * scale))
195
+ for i, h in enumerate(heatmap_tensor):
196
+ heatmap_tensor[i] = cv2.filter2D(h, -1, kernel)
197
+
198
+ heatmap_tensor = torch.FloatTensor(heatmap_tensor)
199
+
200
+ label = torch.cat((heatmap_tensor, label), 0)
201
+
202
+ return {"image": image, "label": label}
203
+
204
+ def furukawa(self, sample):
205
+ image, label = sample["image"], sample["label"]
206
+ _, height, width = label.shape
207
+ heatmap_points = sample["heatmap_points"]
208
+
209
+ heatmap_tensor = np.zeros((21, height, width))
210
+ for channel, coords in heatmap_points.items():
211
+ for x, y in coords:
212
+ heatmap_tensor[int(channel), int(y), int(x)] = 1
213
+
214
+ # Gaussian filter
215
+ kernel = svg_utils.get_gaussian2D(13)
216
+ for i, h in enumerate(heatmap_tensor):
217
+ heatmap_tensor[i] = cv2.filter2D(h, -1, kernel, borderType=cv2.BORDER_CONSTANT, delta=0)
218
+
219
+ heatmap_tensor = torch.FloatTensor(heatmap_tensor)
220
+
221
+ label = torch.cat((heatmap_tensor, label), 0)
222
+
223
+ return {"image": image, "label": label}
224
+
225
+
226
+ class RotateNTurns(object):
227
+ def rot_tensor(self, t, n):
228
+ # One turn clock wise
229
+ if n == 1:
230
+ t = t.flip(2).transpose(3, 2)
231
+ # One turn counter clock wise
232
+ elif n == -1:
233
+ t = t.transpose(3, 2).flip(2)
234
+ # Two turns clock wise
235
+ elif n == 2:
236
+ t = t.flip(2).flip(3)
237
+
238
+ return t
239
+
240
+ def rot_points(self, t, n):
241
+ # Swapping corner ts
242
+ t_sorted = t.clone().detach()
243
+ # One turn clock wise
244
+ if n == 1:
245
+ # I junctions
246
+ t_sorted[:, 1] = t[:, 0]
247
+ t_sorted[:, 2] = t[:, 1]
248
+ t_sorted[:, 3] = t[:, 2]
249
+ t_sorted[:, 0] = t[:, 3]
250
+ # L junctions
251
+ t_sorted[:, 5] = t[:, 4]
252
+ t_sorted[:, 6] = t[:, 5]
253
+ t_sorted[:, 7] = t[:, 6]
254
+ t_sorted[:, 4] = t[:, 7]
255
+ # T junctions
256
+ t_sorted[:, 9] = t[:, 8]
257
+ t_sorted[:, 10] = t[:, 9]
258
+ t_sorted[:, 11] = t[:, 10]
259
+ t_sorted[:, 8] = t[:, 11]
260
+ # Opening corners
261
+ t_sorted[:, 15] = t[:, 13]
262
+ t_sorted[:, 16] = t[:, 14]
263
+ t_sorted[:, 14] = t[:, 15]
264
+ t_sorted[:, 13] = t[:, 16]
265
+ # Icon corners
266
+ t_sorted[:, 18] = t[:, 17]
267
+ t_sorted[:, 20] = t[:, 18]
268
+ t_sorted[:, 17] = t[:, 19]
269
+ t_sorted[:, 19] = t[:, 20]
270
+ # One turn counter clock wise
271
+ elif n == -1:
272
+ # I junctions
273
+ t_sorted[:, 3] = t[:, 0]
274
+ t_sorted[:, 0] = t[:, 1]
275
+ t_sorted[:, 1] = t[:, 2]
276
+ t_sorted[:, 2] = t[:, 3]
277
+ # L junctions
278
+ t_sorted[:, 7] = t[:, 4]
279
+ t_sorted[:, 4] = t[:, 5]
280
+ t_sorted[:, 5] = t[:, 6]
281
+ t_sorted[:, 6] = t[:, 7]
282
+ # T junctions
283
+ t_sorted[:, 11] = t[:, 8]
284
+ t_sorted[:, 8] = t[:, 9]
285
+ t_sorted[:, 9] = t[:, 10]
286
+ t_sorted[:, 10] = t[:, 11]
287
+ # Opening corners
288
+ t_sorted[:, 16] = t[:, 13]
289
+ t_sorted[:, 15] = t[:, 14]
290
+ t_sorted[:, 13] = t[:, 15]
291
+ t_sorted[:, 14] = t[:, 16]
292
+ # Icon corners
293
+ t_sorted[:, 19] = t[:, 17]
294
+ t_sorted[:, 17] = t[:, 18]
295
+ t_sorted[:, 20] = t[:, 19]
296
+ t_sorted[:, 18] = t[:, 20]
297
+ # Two turns clock wise
298
+ elif n == 2:
299
+ t_sorted = t.clone().detach()
300
+ # I junctions
301
+ t_sorted[:, 2] = t[:, 0]
302
+ t_sorted[:, 3] = t[:, 1]
303
+ t_sorted[:, 0] = t[:, 2]
304
+ t_sorted[:, 4] = t[:, 3]
305
+ # L junctions
306
+ t_sorted[:, 6] = t[:, 4]
307
+ t_sorted[:, 7] = t[:, 5]
308
+ t_sorted[:, 4] = t[:, 6]
309
+ t_sorted[:, 5] = t[:, 7]
310
+ # T junctions
311
+ t_sorted[:, 10] = t[:, 8]
312
+ t_sorted[:, 11] = t[:, 9]
313
+ t_sorted[:, 8] = t[:, 10]
314
+ t_sorted[:, 9] = t[:, 11]
315
+ # Opening corners
316
+ t_sorted[:, 14] = t[:, 13]
317
+ t_sorted[:, 13] = t[:, 14]
318
+ t_sorted[:, 16] = t[:, 15]
319
+ t_sorted[:, 15] = t[:, 16]
320
+ # Icon corners
321
+ t_sorted[:, 20] = t[:, 17]
322
+ t_sorted[:, 19] = t[:, 18]
323
+ t_sorted[:, 18] = t[:, 19]
324
+ t_sorted[:, 17] = t[:, 20]
325
+ elif n == 0:
326
+ return t_sorted
327
+
328
+ return t_sorted
329
+
330
+ def __call__(self, sample, data_type, n):
331
+ if data_type == "tensor":
332
+ return self.rot_tensor(sample, n)
333
+ elif data_type == "points":
334
+ return self.rot_points(sample, n)
335
+
336
+
337
+ class RandomCropToSizeTorch(object):
338
+ def __init__(
339
+ self,
340
+ input_slice=[21, 1, 1],
341
+ size=(256, 256),
342
+ fill=(0, 0),
343
+ data_format="tensor",
344
+ dtype=torch.float32,
345
+ max_size=None,
346
+ ):
347
+ self.size = size
348
+ self.width = size[0]
349
+ self.height = size[1]
350
+ self.dtype = dtype
351
+ self.fill = fill
352
+ self.max_size = max_size
353
+ self.input_slice = input_slice
354
+
355
+ if data_format == "dict":
356
+ self.augment = self.augment_dict
357
+ elif data_format == "tensor":
358
+ self.augment = self.augment_tesor
359
+ elif data_format == "dict furu":
360
+ self.augment = self.augment_dict_furu
361
+
362
+ def __call__(self, sample):
363
+ return self.augment(sample)
364
+
365
+ def augment_tesor(self, sample):
366
+ image, label = sample["image"], sample["label"]
367
+ img_w = image.shape[2]
368
+ img_h = image.shape[1]
369
+ pad_w = int(self.width / 2)
370
+ pad_h = int(self.height / 2)
371
+
372
+ new_w = self.width + max(img_w, self.width)
373
+ new_h = self.height + max(img_h, self.height)
374
+
375
+ new_image = torch.zeros([image.shape[0], new_h, new_w], dtype=self.dtype)
376
+ new_image[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = image
377
+
378
+ new_heatmaps = torch.zeros([self.input_slice[0], new_h, new_w], dtype=self.dtype)
379
+ new_heatmaps[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = label[: self.input_slice[0]]
380
+
381
+ new_rooms = torch.full((self.input_slice[1], new_h, new_w), self.fill[0])
382
+ new_rooms[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = label[self.input_slice[0]]
383
+ new_icons = torch.full((self.input_slice[2], new_h, new_w), self.fill[1])
384
+ new_icons[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = label[self.input_slice[0] + self.input_slice[1]]
385
+
386
+ label = torch.cat((new_heatmaps, new_rooms, new_icons), 0)
387
+ image = new_image
388
+
389
+ removed_up = random.randint(0, new_h - self.width)
390
+ removed_left = random.randint(0, new_w - self.height)
391
+
392
+ removed_down = new_h - self.height - removed_up
393
+ removed_right = new_w - self.width - removed_left
394
+
395
+ if removed_down == 0 and removed_right == 0:
396
+ image = image[:, removed_up:, removed_left:]
397
+ label = label[:, removed_up:, removed_left:]
398
+ elif removed_down == 0:
399
+ image = image[:, removed_up:, removed_left:-removed_right]
400
+ label = label[:, removed_up:, removed_left:-removed_right]
401
+ elif removed_right == 0:
402
+ image = image[:, removed_up:-removed_down, removed_left:]
403
+ label = label[:, removed_up:-removed_down, removed_left:]
404
+ else:
405
+ image = image[:, removed_up:-removed_down, removed_left:-removed_right]
406
+ label = label[:, removed_up:-removed_down, removed_left:-removed_right]
407
+
408
+ return {"image": image, "label": label}
409
+
410
+ def augment_dict(self, sample):
411
+ image, label = sample["image"], sample["label"]
412
+ heatmap_points = sample["heatmaps"]
413
+ img_w = image.shape[2]
414
+ img_h = image.shape[1]
415
+ pad_w = int(self.width / 2)
416
+ pad_h = int(self.height / 2)
417
+
418
+ new_w = self.width + img_w
419
+ new_h = self.height + img_h
420
+
421
+ new_image = torch.full([image.shape[0], new_h, new_w], 255)
422
+ new_image[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = image
423
+
424
+ new_rooms = torch.full((1, new_h, new_w), self.fill[0])
425
+ new_rooms[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = label[0]
426
+ new_icons = torch.full((1, new_h, new_w), self.fill[1])
427
+ new_icons[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = label[1]
428
+
429
+ label = torch.cat((new_rooms, new_icons), 0)
430
+ image = new_image
431
+
432
+ removed_up = random.randint(0, new_h - self.width)
433
+ removed_left = random.randint(0, new_w - self.height)
434
+
435
+ removed_down = new_h - self.height - removed_up
436
+ removed_right = new_w - self.width - removed_left
437
+
438
+ new_heatmap_points = dict()
439
+ for junction_type, points in heatmap_points.items():
440
+ new_heatmap_points_per_type = []
441
+ for point in points:
442
+ new_heatmap_points_per_type.append([point[0] + pad_w, point[1] + pad_h])
443
+
444
+ new_heatmap_points[junction_type] = new_heatmap_points_per_type
445
+
446
+ heatmap_points = new_heatmap_points
447
+
448
+ if removed_down == 0 and removed_right == 0:
449
+ image = image[:, removed_up:, removed_left:]
450
+ label = label[:, removed_up:, removed_left:]
451
+ heatmap_points = clip_heatmaps(heatmap_points, removed_left, inf, removed_up, inf)
452
+ elif removed_down == 0:
453
+ image = image[:, removed_up:, removed_left:-removed_right]
454
+ label = label[:, removed_up:, removed_left:-removed_right]
455
+ heatmap_points = clip_heatmaps(heatmap_points, removed_left, removed_left + self.width, removed_up, inf)
456
+ elif removed_right == 0:
457
+ image = image[:, removed_up:-removed_down, removed_left:]
458
+ label = label[:, removed_up:-removed_down, removed_left:]
459
+ heatmap_points = clip_heatmaps(heatmap_points, removed_left, inf, removed_up, removed_up + self.width)
460
+ else:
461
+ image = image[:, removed_up:-removed_down, removed_left:-removed_right]
462
+ label = label[:, removed_up:-removed_down, removed_left:-removed_right]
463
+ heatmap_points = clip_heatmaps(
464
+ heatmap_points, removed_left, removed_left + self.width, removed_up, removed_up + self.height
465
+ )
466
+
467
+ return {"image": image, "label": label, "heatmaps": heatmap_points, "scale": sample["scale"]}
468
+
469
+ def augment_dict_furu(self, sample):
470
+ image, label = sample["image"], sample["label"]
471
+ heatmap_points = sample["heatmap_points"]
472
+ img_w = image.shape[2]
473
+ img_h = image.shape[1]
474
+ pad_w = int(self.width / 2)
475
+ pad_h = int(self.height / 2)
476
+
477
+ new_w = self.width + img_w
478
+ new_h = self.height + img_h
479
+
480
+ new_image = torch.full([image.shape[0], new_h, new_w], 255)
481
+ new_image[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = image
482
+
483
+ new_rooms = torch.full((1, new_h, new_w), self.fill[0])
484
+ new_rooms[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = label[0]
485
+ new_icons = torch.full((1, new_h, new_w), self.fill[1])
486
+ new_icons[:, pad_h : img_h + pad_h, pad_w : img_w + pad_w] = label[1]
487
+
488
+ label = torch.cat((new_rooms, new_icons), 0)
489
+ image = new_image
490
+
491
+ removed_up = random.randint(0, new_h - self.width)
492
+ removed_left = random.randint(0, new_w - self.height)
493
+
494
+ removed_down = new_h - self.height - removed_up
495
+ removed_right = new_w - self.width - removed_left
496
+
497
+ new_heatmap_points = dict()
498
+ for junction_type, points in heatmap_points.items():
499
+ new_heatmap_points_per_type = []
500
+ for point in points:
501
+ new_heatmap_points_per_type.append([point[0] + pad_w, point[1] + pad_h])
502
+
503
+ new_heatmap_points[junction_type] = new_heatmap_points_per_type
504
+
505
+ heatmap_points = new_heatmap_points
506
+
507
+ if removed_down == 0 and removed_right == 0:
508
+ image = image[:, removed_up:, removed_left:]
509
+ label = label[:, removed_up:, removed_left:]
510
+ heatmap_points = clip_heatmaps(heatmap_points, removed_left, inf, removed_up, inf)
511
+ elif removed_down == 0:
512
+ image = image[:, removed_up:, removed_left:-removed_right]
513
+ label = label[:, removed_up:, removed_left:-removed_right]
514
+ heatmap_points = clip_heatmaps(heatmap_points, removed_left, removed_left + self.width, removed_up, inf)
515
+ elif removed_right == 0:
516
+ image = image[:, removed_up:-removed_down, removed_left:]
517
+ label = label[:, removed_up:-removed_down, removed_left:]
518
+ heatmap_points = clip_heatmaps(heatmap_points, removed_left, inf, removed_up, removed_up + self.width)
519
+ else:
520
+ image = image[:, removed_up:-removed_down, removed_left:-removed_right]
521
+ label = label[:, removed_up:-removed_down, removed_left:-removed_right]
522
+ heatmap_points = clip_heatmaps(
523
+ heatmap_points, removed_left, removed_left + self.width, removed_up, removed_up + self.height
524
+ )
525
+
526
+ return {"image": image, "label": label, "heatmap_points": heatmap_points}
527
+
528
+
529
+ class ColorJitterTorch(object):
530
+ def __init__(self, b_var=0.4, c_var=0.4, s_var=0.4, dtype=torch.float32, version="dict"):
531
+ self.b_var = b_var
532
+ self.c_var = c_var
533
+ self.s_var = s_var
534
+ self.dtype = dtype
535
+ self.version = version
536
+
537
+ def __call__(self, sample):
538
+ res = sample
539
+ image = sample["image"]
540
+ image = self.brightness(image, self.b_var)
541
+ image = self.contrast(image, self.c_var)
542
+ image = self.saturation(image, self.s_var)
543
+ res["image"] = image
544
+
545
+ return res
546
+
547
+ def blend(self, img_1, img_2, var):
548
+ m = torch.tensor([0], dtype=self.dtype).uniform_(-var, var)
549
+ alpha = 1 + m
550
+ res = img_1 * alpha + (1 - alpha) * img_2
551
+ res = torch.clamp(res, min=0, max=255)
552
+
553
+ return res
554
+
555
+ def grayscale(self, img):
556
+ red = img[0] * 0.299
557
+ green = img[1] * 0.587
558
+ blue = img[2] * 0.114
559
+ gray = red + green + blue
560
+ gray = torch.clamp(gray, min=0, max=255)
561
+ res = torch.stack((gray, gray, gray), dim=0)
562
+
563
+ return res
564
+
565
+ def saturation(self, img, var):
566
+ res = self.grayscale(img)
567
+ res = self.blend(img, res, var)
568
+
569
+ return res
570
+
571
+ def brightness(self, img, var):
572
+ res = torch.zeros(img.shape)
573
+ res = self.blend(img, res, var)
574
+
575
+ return res
576
+
577
+ def contrast(self, img, var):
578
+ res = self.grayscale(img)
579
+ mean_color = res.mean()
580
+ res = torch.full(res.shape, mean_color)
581
+ res = self.blend(img, res, var)
582
+
583
+ return res
584
+
585
+
586
+ class ResizePaddedTorch(object):
587
+ def __init__(self, fill, size=(256, 256), both=True, dtype=torch.float32, data_format="tensor"):
588
+ self.size = size
589
+ self.width = size[0]
590
+ self.height = size[1]
591
+ self.both = both
592
+ self.dtype = dtype
593
+ self.fill = fill
594
+ self.fill_cval = 255
595
+ if data_format == "tensor":
596
+ self.augment = self.augment_tensor
597
+ elif data_format == "dict furu":
598
+ self.augment = self.augment_dict_furu
599
+ elif data_format == "dict":
600
+ self.augment = self.augment_dict
601
+ self.fill_cval = 1
602
+
603
+ def __call__(self, sample):
604
+ # image 1: Bi-cubic interpolation as in original paper
605
+ image, _, _, _ = self.resize_padded(
606
+ sample["image"], self.size, fill_cval=self.fill_cval, image=True, mode="bilinear", aling_corners=False
607
+ )
608
+ sample["image"] = image
609
+
610
+ return self.augment(sample)
611
+
612
+ def augment_tensor(self, sample):
613
+ image, label = sample["image"], sample["label"]
614
+
615
+ if self.both:
616
+ # labels 0: Nearest-neighbor interpolation
617
+ heatmaps, _, _, _ = self.resize_padded(label[:21], self.size, mode="bilinear", aling_corners=False)
618
+ rooms_padded, _, _, _ = self.resize_padded(label[[21]], self.size, mode="nearest", fill_cval=self.fill[0])
619
+ icons_padded, _, _, _ = self.resize_padded(
620
+ label[[22]],
621
+ self.size,
622
+ mode="nearest",
623
+ fill_cval=self.fill[1],
624
+ )
625
+ label = torch.cat((heatmaps, rooms_padded, icons_padded), dim=0)
626
+
627
+ return {"image": image, "label": label}
628
+
629
+ def augment_dict_furu(self, sample):
630
+ image, label = sample["image"], sample["label"]
631
+ heatmap_points = sample["heatmap_points"]
632
+
633
+ rooms_padded, _, _, _ = self.resize_padded(label[[0]], self.size, mode="nearest", fill_cval=self.fill[0])
634
+ icons_padded, ratio, y_pad, x_pad = self.resize_padded(
635
+ label[[1]], self.size, mode="nearest", fill_cval=self.fill[1]
636
+ )
637
+ label = torch.cat((rooms_padded, icons_padded), dim=0)
638
+
639
+ new_heatmap_points = dict()
640
+ for junction_type, points in heatmap_points.items():
641
+ new_heatmap_points_per_type = []
642
+ for point in points:
643
+ # Indexing starts from 0 but when we multiply with the ratio we need to start from 0.
644
+ new_x = point[0] * ratio + x_pad
645
+ new_y = point[1] * ratio + y_pad
646
+ new_heatmap_points_per_type.append([new_x, new_y])
647
+ new_heatmap_points[junction_type] = new_heatmap_points_per_type
648
+
649
+ heatmap_points = new_heatmap_points
650
+
651
+ return {"image": image, "label": label, "heatmap_points": heatmap_points}
652
+
653
+ def augment_dict(self, sample):
654
+ image, label = sample["image"], sample["label"]
655
+ heatmap_points = sample["heatmaps"]
656
+ scale = sample["scale"]
657
+
658
+ rooms_padded, _, _, _ = self.resize_padded(label[[0]], self.size, mode="nearest", fill_cval=self.fill[0])
659
+ icons_padded, ratio, y_pad, x_pad = self.resize_padded(
660
+ label[[1]], self.size, mode="nearest", fill_cval=self.fill[1]
661
+ )
662
+ label = torch.cat((rooms_padded, icons_padded), dim=0)
663
+
664
+ new_heatmap_points = dict()
665
+ for junction_type, points in heatmap_points.items():
666
+ new_heatmap_points_per_type = []
667
+ for point in points:
668
+ # Indexing starts from 0 but when we multiply with the ratio we need to start from 0.
669
+ new_x = point[0] * ratio + x_pad
670
+ new_y = point[1] * ratio + y_pad
671
+ if new_y < 256 and new_x < 256 and new_y >= 0 and new_x >= 0:
672
+ # __import__('ipdb').set_trace()
673
+ new_heatmap_points_per_type.append([new_x, new_y])
674
+ new_heatmap_points[junction_type] = new_heatmap_points_per_type
675
+
676
+ heatmap_points = new_heatmap_points
677
+
678
+ return {"image": image, "label": label, "heatmaps": heatmap_points, "scale": scale}
679
+
680
+ def resize_padded(self, img, new_shape, image=False, fill_cval=0, mode="nearest", aling_corners=None):
681
+ new_shape = torch.tensor([img.shape[0], new_shape[0], new_shape[1]], dtype=self.dtype)
682
+ old_shape = torch.tensor(img.shape, dtype=self.dtype)
683
+
684
+ ratio = (new_shape / old_shape).min()
685
+ img_s = torch.tensor(img.shape[1:], dtype=self.dtype)
686
+ interm_shape = (ratio * img_s).ceil()
687
+
688
+ interm_shape = [interm_shape[0], interm_shape[1]]
689
+
690
+ img = img.unsqueeze(0)
691
+ interm_img = torch.nn.functional.interpolate(img, size=interm_shape, mode=mode, align_corners=aling_corners)
692
+ interm_img = interm_img.squeeze(0)
693
+
694
+ a = (interm_img.shape[0], self.size[0], self.size[1])
695
+
696
+ new_img = torch.full(a, fill_cval)
697
+
698
+ x_pad = int((self.width - interm_img.shape[1]) / 2)
699
+ y_pad = int((self.height - interm_img.shape[2]) / 2)
700
+
701
+ new_img[:, x_pad : interm_img.shape[1] + x_pad, y_pad : interm_img.shape[2] + y_pad] = interm_img
702
+
703
+ return new_img, ratio, x_pad, y_pad
data_preprocess/cubicasa5k/combine_json.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import os
4
+ import shutil
5
+ from pathlib import Path
6
+
7
+
8
+ def combine_json_files(input_pattern, data_path, split_type, output_file, start_image_id=0):
9
+ """
10
+ Combines multiple COCO-style JSON annotation files into a single file.
11
+
12
+ Args:
13
+ input_pattern: Glob pattern to match the input JSON files (e.g., "annotations/*.json")
14
+ output_file: Path to the output combined JSON file
15
+ """
16
+ # Initialize combined data structure
17
+ combined_data = {"images": [], "annotations": [], "categories": []}
18
+
19
+ # Track image and annotation IDs to avoid duplicates
20
+ annotation_ids_seen = set()
21
+
22
+ next_image_id = start_image_id
23
+ next_annotation_id = 0
24
+ skip_file_list = []
25
+ image_id_mapping = {}
26
+
27
+ # Find all matching JSON files
28
+ json_files = sorted(glob.glob(input_pattern))
29
+ print(f"Found {len(json_files)} JSON files to combine")
30
+
31
+ # Process each file
32
+ for i, json_file in enumerate(json_files):
33
+ print(f"Processing file {i + 1}/{len(json_files)}: {json_file}")
34
+
35
+ with open(json_file, "r") as f:
36
+ data = json.load(f)
37
+
38
+ # Store categories from the first file
39
+ if i == 0 and data.get("categories"):
40
+ combined_data["categories"] = data["categories"]
41
+
42
+ # empty annos
43
+ if len(data["annotations"]) == 0:
44
+ skip_file_list.append(data["images"][0]["id"])
45
+ continue
46
+
47
+ # Process images
48
+ for image in data.get("images", []):
49
+ if image["id"] not in image_id_mapping:
50
+ image_id_mapping[image["id"]] = next_image_id
51
+ else:
52
+ skip_file_list.append(image["id"])
53
+ continue
54
+ image["id"] = next_image_id
55
+ next_image_id += 1
56
+ image["file_name"] = str(image["id"]).zfill(5) + ".png"
57
+ org_file_name = os.path.basename(json_file).replace(".json", ".png")
58
+ if image["file_name"] != org_file_name and os.path.exists(f"{data_path}/{split_type}/{org_file_name}"):
59
+ shutil.move(
60
+ f"{data_path}/{split_type}/{org_file_name}", f"{data_path}/{split_type}/{image['file_name']}"
61
+ )
62
+ combined_data["images"].append(image)
63
+
64
+ # Process annotations
65
+ for annotation in data.get("annotations", []):
66
+ annotation["id"] = next_annotation_id
67
+ next_annotation_id += 1
68
+ annotation["image_id"] = image_id_mapping[annotation["image_id"]]
69
+
70
+ annotation_ids_seen.add(annotation["id"])
71
+ combined_data["annotations"].append(annotation)
72
+
73
+ # Write combined data to output file
74
+ output_path = Path(output_file)
75
+ output_path.parent.mkdir(exist_ok=True, parents=True)
76
+
77
+ with open(output_file, "w") as f:
78
+ json.dump(combined_data, f, indent=2)
79
+
80
+ with open(output_path.parent / f"{output_path.name.split('.')[0]}_image_id_mapping.json", "w") as f:
81
+ json.dump(image_id_mapping, f, indent=2)
82
+
83
+ if len(skip_file_list):
84
+ with open(output_path.parent / f"{output_path.name.split('.')[0]}_skipped.txt", "w") as f:
85
+ f.write("\n".join([str(x) for x in skip_file_list]))
86
+
87
+ print(f"Combined data written to {output_file}")
88
+ print(f"Total images: {len(combined_data['images'])}")
89
+ print(f"Total annotations: {len(combined_data['annotations'])}")
90
+ print(f"Total categories: {len(combined_data['categories'])}")
91
+ print(f"Skipped images: {len(skip_file_list)}")
92
+
93
+ return combined_data
94
+
95
+
96
+ if __name__ == "__main__":
97
+ import argparse
98
+
99
+ parser = argparse.ArgumentParser(description="Combine multiple COCO-style JSON annotation files")
100
+ parser.add_argument("--input", required=True, help="Glob pattern for input JSON files, e.g., 'annotations/*.json'")
101
+ parser.add_argument("--output", required=True, help="Output JSON file path")
102
+
103
+ args = parser.parse_args()
104
+
105
+ splits = ["train", "val", "test"]
106
+ for i, split in enumerate(splits):
107
+ if split == "train":
108
+ start_image_id = 0
109
+ else:
110
+ start_image_id += len(list(Path(f"{args.input}/{splits[i - 1]}").glob("*.png")))
111
+
112
+ combine_json_files(
113
+ f"{args.input}/annotations_json/{split}/*.json",
114
+ args.input,
115
+ split,
116
+ f"{args.output}/{split}.json",
117
+ start_image_id=start_image_id,
118
+ )
data_preprocess/cubicasa5k/create_coco_cc5k.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ from multiprocessing import Pool
6
+ from pathlib import Path
7
+
8
+ import cv2
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ from loaders import FloorplanSVG
12
+ from matplotlib.patches import Patch
13
+ from PIL import Image
14
+ from shapely.geometry import Polygon
15
+ from skimage import measure
16
+ from tqdm import tqdm
17
+
18
+ sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
19
+
20
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
21
+ from common_utils import resort_corners
22
+ from stru3d.stru3d_utils import type2id
23
+
24
+ #### ORIGINAL ROOM NAMES & ICON_NAMES ####
25
+ ROOM_NAMES = {
26
+ 0: "Background",
27
+ 1: "Outdoor",
28
+ 2: "Wall",
29
+ 3: "Kitchen",
30
+ 4: "Living Room",
31
+ 5: "Bed Room",
32
+ 6: "Bath",
33
+ 7: "Entry",
34
+ 8: "Railing",
35
+ 9: "Storage",
36
+ 10: "Garage",
37
+ 11: "Undefined",
38
+ }
39
+
40
+ ICON_NAMES = {
41
+ 0: "No Icon",
42
+ 1: "Window",
43
+ 2: "Door",
44
+ 3: "Closet",
45
+ 4: "Electrical Applience",
46
+ 5: "Toilet",
47
+ 6: "Sink",
48
+ 7: "Sauna Bench",
49
+ 8: "Fire Place",
50
+ 9: "Bathtub",
51
+ 10: "Chimney",
52
+ }
53
+
54
+
55
+ CC5K_2_S3D_MAPPING = {
56
+ 0: None, # "Background"
57
+ 1: type2id["balcony"], # "Outdoor" -> balcony (4)
58
+ 2: None, # "Wall" has no direct match
59
+ 3: type2id["kitchen"], # Kitchen -> kitchen (1)
60
+ 4: type2id["living room"], # Living Room -> living room (0)
61
+ 5: type2id["bedroom"], # Bed Room -> bedroom (2)
62
+ 6: type2id["bathroom"], # Bath -> bathroom (3)
63
+ 7: 18, # 'Entry' has no direct match
64
+ 8: 19, # "Railing" has no direct match
65
+ 9: type2id["store room"], # Storage -> store room (9)
66
+ 10: type2id["garage"], # Garage -> garage (14)
67
+ 11: type2id["undefined"], # Undefined -> undefined (15)
68
+ 12: type2id["window"], # Window -> window (17)
69
+ 13: type2id["door"], # Door -> door (16)
70
+ }
71
+
72
+ CC5K_MAPPING = {
73
+ 0: None,
74
+ 1: 0, # Outdoor
75
+ 2: 1, # Wall
76
+ 3: 2, # Kitchen
77
+ 4: 3, # Living Room
78
+ 5: 4, # Bed Room
79
+ 6: 5, # Bath
80
+ 7: 6, # Entry
81
+ 8: 1, # Railing -> Wall
82
+ 9: 7, # Storage
83
+ 10: 8, # Garage
84
+ 11: 9, # Undefined
85
+ 12: 10, # Window
86
+ 13: 11, # Door
87
+ }
88
+
89
+ CC5K_MAPPING_2 = {
90
+ 0: None,
91
+ 1: 0, # Outdoor
92
+ 2: None, # Wall
93
+ 3: 1, # Kitchen
94
+ 4: 2, # Living Room
95
+ 5: 3, # Bed Room
96
+ 6: 4, # Bath
97
+ 7: 5, # Entry
98
+ 8: None, # Railing -> Wall
99
+ 9: 6, # Storage
100
+ 10: 7, # Garage
101
+ 11: 8, # Undefined
102
+ 12: 9, # Window
103
+ 13: 10, # Door
104
+ }
105
+
106
+ CC5K_CLASS_MAPPING = {
107
+ "Outdoor": 0,
108
+ "Wall, Railing": 1,
109
+ "Kitchen": 2,
110
+ "Living Room": 3,
111
+ "Bed Room": 4,
112
+ "Bath": 5,
113
+ "Entry": 6,
114
+ "Storage": 7,
115
+ "Garage": 8,
116
+ "Undefined": 9,
117
+ "Window": 10,
118
+ "Door": 11,
119
+ }
120
+
121
+ CC5K_CLASS_MAPPING_2 = {
122
+ "Outdoor": 0,
123
+ "Kitchen": 1,
124
+ "Living Room": 2,
125
+ "Bed Room": 3,
126
+ "Bath": 4,
127
+ "Entry": 5,
128
+ "Storage": 6,
129
+ "Garage": 7,
130
+ "Undefined": 8,
131
+ "Window": 9,
132
+ "Door": 10,
133
+ }
134
+
135
+ CLASS_MAPPING = {
136
+ "living room": 0,
137
+ "kitchen": 1,
138
+ "bedroom": 2,
139
+ "bathroom": 3,
140
+ "balcony": 4,
141
+ "corridor": 5,
142
+ "dining room": 6,
143
+ "study": 7,
144
+ "studio": 8,
145
+ "store room": 9,
146
+ "garden": 10,
147
+ "laundry room": 11,
148
+ "office": 12,
149
+ "basement": 13,
150
+ "garage": 14,
151
+ "undefined": 15,
152
+ "door": 16,
153
+ "window": 17,
154
+ "entry": 18,
155
+ "railing": 19,
156
+ }
157
+
158
+
159
+ def fill_holes_in_mask(binary_mask):
160
+ """
161
+ Fill 0-pixels in a binary mask that are completely surrounded by 1-pixels.
162
+
163
+ Args:
164
+ binary_mask (numpy.ndarray): Binary mask with 0 and 1 values.
165
+
166
+ Returns:
167
+ numpy.ndarray: Binary mask with holes filled.
168
+ """
169
+ # Ensure the mask is binary (0 and 1)
170
+ binary_mask = (binary_mask > 0).astype(np.uint8)
171
+
172
+ # Apply dilation
173
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
174
+ binary_mask = cv2.dilate(binary_mask, kernel, iterations=1)
175
+
176
+ # Find contours in the mask
177
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
178
+
179
+ # Fill the contours
180
+ filled_mask = binary_mask.copy()
181
+ cv2.fillPoly(filled_mask, contours, 1)
182
+
183
+ return filled_mask
184
+
185
+
186
+ def close_contour(contour):
187
+ if not np.array_equal(contour[0], contour[-1]):
188
+ contour = np.vstack((contour, contour[0]))
189
+ return contour
190
+
191
+
192
+ def binary_mask_to_polygon(binary_mask, tolerance=0):
193
+ """Converts a binary mask to COCO polygon representation
194
+ Ref: https://github.com/waspinator/pycococreator/blob/master/pycococreatortools/pycococreatortools.py
195
+
196
+ Args:
197
+ binary_mask: a 2D binary numpy array where '1's represent the object
198
+ tolerance: Maximum distance from original points of polygon to approximated
199
+ polygonal chain. If tolerance is 0, the original coordinate array is returned.
200
+
201
+ """
202
+ polygons = []
203
+ # pad mask to close contours of shapes which start and end at an edge
204
+ padded_binary_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0)
205
+ contours = measure.find_contours(padded_binary_mask, 0.5)
206
+ contours = np.subtract(contours, 1)
207
+ for contour in contours:
208
+ contour = close_contour(contour)
209
+ contour = measure.approximate_polygon(contour, tolerance)
210
+ if len(contour) < 3:
211
+ continue
212
+ contour = np.flip(contour, axis=1)
213
+ segmentation = contour.ravel().tolist()
214
+ # after padding and subtracting 1 we may get -0.5 points in our segmentation
215
+ segmentation = [0 if i < 0 else i for i in segmentation]
216
+ polygons.append(segmentation)
217
+
218
+ return polygons
219
+
220
+
221
+ def extract_icon_cv2(mask, start_cls_id=11, skip_classes=[]):
222
+ room_ids = np.unique(mask)
223
+ room_polygons = []
224
+ new_mask = np.zeros(mask.shape)
225
+
226
+ # window, door
227
+ for room_id in room_ids:
228
+ if room_id in skip_classes:
229
+ continue
230
+ true_room_id = int(room_id) + start_cls_id
231
+ # Create binary mask for this room
232
+ room_mask = (mask == room_id).astype(np.uint8)
233
+ new_mask = np.where(room_mask, true_room_id, 0)
234
+
235
+ # Find contours using OpenCV
236
+ contours, _ = cv2.findContours(room_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
237
+
238
+ if contours:
239
+ # # Get the largest contour
240
+ # largest_contour = max(contours, key=cv2.contourArea)
241
+ for cnt in contours:
242
+ polygon = [tuple(point[0]) for point in cnt]
243
+ if len(polygon) < 3:
244
+ continue
245
+
246
+ poly = Polygon(polygon)
247
+ simplified_poly = poly.simplify(tolerance=0.5, preserve_topology=True)
248
+ simplified_poly = list(simplified_poly.exterior.coords)
249
+ room_polygons.append([simplified_poly, true_room_id])
250
+
251
+ return room_polygons, new_mask
252
+
253
+
254
+ def visualize_room_polygons(mask, room_polygons, class_names, save_path="cubicasa_debug.png", bg_polygons=None):
255
+ """
256
+ Visualize the extracted room polygons.
257
+
258
+ Args:
259
+ mask: Original segmentation mask
260
+ room_polygons: Dictionary of room polygons as returned by extract_room_polygons
261
+ figsize: Figure size for the plot
262
+ """
263
+ # Set figure size to exactly 256x256 pixels
264
+ dpi = 100 # Standard screen DPI
265
+ figsize = (mask.shape[1] / dpi, mask.shape[0] / dpi) # Convert pixels to inches
266
+
267
+ # Get unique classes from the mask
268
+ unique_classes = np.unique(mask)
269
+
270
+ # Create a discrete colormap
271
+ cmap = plt.cm.get_cmap("gist_ncar", 256) # nipy_spectral
272
+ # cmap = ListedColormap([cmap(x) for x in np.linspace(0, 1, int(20))])
273
+
274
+ fig = plt.figure(figsize=figsize)
275
+ ax = fig.add_axes([0, 0, 1, 1])
276
+ plt.imshow(mask, cmap=cmap, interpolation="nearest", alpha=0.6, vmin=0, vmax=20)
277
+
278
+ # Plot each room polygon
279
+ for polygon, room_cls in room_polygons:
280
+ polygon_array = np.array(polygon).copy()
281
+ # # flip y
282
+ # polygon_array[:, 1] = mask.shape[0] - polygon_array[:, 1] - 1
283
+ ax.plot(polygon_array[:, 0], polygon_array[:, 1], "k-", linewidth=2)
284
+
285
+ # Add room ID label at the centroid
286
+ centroid_x = np.mean(polygon_array[:, 0])
287
+ centroid_y = np.mean(polygon_array[:, 1])
288
+ ax.text(
289
+ centroid_x,
290
+ centroid_y,
291
+ str(room_cls),
292
+ fontsize=12,
293
+ ha="center",
294
+ va="center",
295
+ bbox=dict(facecolor="white", alpha=0.7),
296
+ )
297
+
298
+ if bg_polygons is not None:
299
+ # Plot each room polygon
300
+ for polygon, room_cls in bg_polygons:
301
+ polygon_array = np.array(polygon).copy()
302
+ # # flip y
303
+ # polygon_array[:, 1] = mask.shape[0] - polygon_array[:, 1] - 1
304
+ ax.plot(polygon_array[:, 0], polygon_array[:, 1], "c-", linewidth=2)
305
+
306
+ # Create custom legend elements
307
+ legend_elements = []
308
+ norm = np.linspace(0, 1, 21) # int(max(unique_classes))+1
309
+
310
+ for i, cls in enumerate(sorted(unique_classes)):
311
+ # if int(cls) == 0:
312
+ # continue
313
+ # Get the exact same color that imshow uses
314
+ color = cmap(norm[int(cls)])
315
+ # color = cmap(int(cls))
316
+
317
+ cls_name = f"{int(cls)}_{class_names[int(cls)]}"
318
+ # You can replace f"Class {cls}" with your actual class names if available
319
+ legend_elements.append(Patch(facecolor=color, edgecolor="black", label=f"{cls_name}", alpha=0.6))
320
+
321
+ # Add the legend to the plot
322
+ ax.legend(
323
+ handles=legend_elements,
324
+ loc="best",
325
+ title="Classes",
326
+ fontsize=20,
327
+ markerscale=4,
328
+ title_fontsize=28,
329
+ )
330
+
331
+ # plt.title('Room Polygons Extracted from Segmentation Mask')
332
+ plt.axis("equal")
333
+ plt.axis("off")
334
+ fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
335
+ plt.close()
336
+
337
+
338
+ def config():
339
+ a = argparse.ArgumentParser(description="Generate coco format data for Structured3D")
340
+ a.add_argument(
341
+ "--data_root", default="Structured3D_panorama", type=str, help="path to raw Structured3D_panorama folder"
342
+ )
343
+ a.add_argument("--output", default="coco_cubicasa5k", type=str, help="path to output folder")
344
+ a.add_argument("--disable_wd2line", action="store_true")
345
+
346
+ args = a.parse_args()
347
+ return args
348
+
349
+
350
+ def save_image(image_path, output_path, mask=None):
351
+ """
352
+ ref: https://github.com/ultralytics/ultralytics/issues/339
353
+ """
354
+ img = Image.open(image_path).convert("RGB")
355
+ img.info.pop("icc_profile", None)
356
+
357
+ if mask is not None:
358
+ img_array = np.array(img)
359
+ if len(mask.shape) == 2 and len(img_array.shape) == 3:
360
+ mask = mask[:, :, np.newaxis]
361
+ masked_img = np.where(mask == 0, 255, img_array)
362
+ img = Image.fromarray(masked_img.astype(np.uint8))
363
+
364
+ img.save(output_path)
365
+
366
+
367
+ def remove_polygons_by_type(polygons, skip_types=[]):
368
+ new_room_polygons = []
369
+ for polygon, poly_type in polygons:
370
+ if poly_type in skip_types:
371
+ continue
372
+ new_room_polygons.append([polygon, poly_type])
373
+ return new_room_polygons
374
+
375
+
376
+ def merge_rooms_and_icons(room_polygons, icon_polygons):
377
+ new_icon_polygons = []
378
+ for poly, poly_type in icon_polygons:
379
+ new_icon_polygons.append([poly, poly_type + 11])
380
+
381
+ return room_polygons + new_icon_polygons
382
+
383
+
384
+ def create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2):
385
+ bb_x = np.unique(bb_x)
386
+ bb_y = np.unique(bb_y)
387
+ bb_x_min = np.maximum(np.min(bb_x) - bound_pad, 0)
388
+ bb_y_min = np.maximum(np.min(bb_y) - bound_pad, 0)
389
+
390
+ bb_x_max = np.minimum(np.max(bb_x) + bound_pad, image_width - 1)
391
+ bb_y_max = np.minimum(np.max(bb_y) + bound_pad, image_height - 1)
392
+
393
+ bb_width = bb_x_max - bb_x_min
394
+ bb_height = bb_y_max - bb_y_min
395
+
396
+ coco_bb = [bb_x_min, bb_y_min, bb_width, bb_height]
397
+ return coco_bb
398
+
399
+
400
+ def process_floorplan(
401
+ image_set,
402
+ scene_id,
403
+ start_scene_id,
404
+ args,
405
+ save_dir,
406
+ annos_folder,
407
+ use_org_cc5k_classs=False,
408
+ vis_fp=False,
409
+ wd2line=False,
410
+ ):
411
+ if use_org_cc5k_classs:
412
+ class_mapping_dict = CC5K_MAPPING_2 # old: CC5K_MAPPING
413
+ class_to_index_dict = CC5K_CLASS_MAPPING_2
414
+ door_window_index = [10, 9]
415
+ else:
416
+ class_mapping_dict = CC5K_2_S3D_MAPPING
417
+ class_to_index_dict = CLASS_MAPPING
418
+ door_window_index = [16, 17]
419
+
420
+ mask = image_set["label"].numpy()
421
+ room_polygons = [[poly, poly_type] for poly, poly_type in zip(image_set["room_polygon"], image_set["room_type"])]
422
+ icon_polygons = [[poly, poly_type] for poly, poly_type in zip(image_set["icon_polygon"], image_set["icon_type"])]
423
+
424
+ image_height, image_width = mask.shape[1:]
425
+ coco_annotation_dict_list = []
426
+
427
+ # for storing
428
+ save_dict = prepare_dict(class_to_index_dict) # old: CC5K_CLASS_MAPPING
429
+
430
+ instance_id = 0
431
+ img_id = int(scene_id) + start_scene_id
432
+ img_dict = {}
433
+ img_dict["file_name"] = str(img_id).zfill(5) + ".png"
434
+ img_dict["id"] = img_id
435
+ img_dict["width"] = image_width
436
+ img_dict["height"] = image_height
437
+
438
+ if vis_fp:
439
+ os.makedirs(save_dir.rstrip("/") + "_aux", exist_ok=True)
440
+ visualize_room_polygons(
441
+ mask[0],
442
+ room_polygons,
443
+ list(ROOM_NAMES.values()),
444
+ save_path=f"{save_dir.rstrip('/') + '_aux'}/{str(img_id).zfill(5)}_room.png",
445
+ )
446
+ visualize_room_polygons(
447
+ mask[1],
448
+ icon_polygons,
449
+ list(ICON_NAMES.values()),
450
+ bg_polygons=room_polygons,
451
+ save_path=f"{save_dir.rstrip('/') + '_aux'}/{str(img_id).zfill(5)}_icon.png",
452
+ )
453
+
454
+ #### FILTER NON-USE TYPES
455
+ # DROP BG
456
+ room_skip_types = [0]
457
+ filtered_room_polygons = remove_polygons_by_type(room_polygons, skip_types=room_skip_types)
458
+ # visualize_room_polygons(mask[0], filtered_room_polygons, list(ROOM_NAMES.values()),
459
+ # save_path=f"{save_dir.rstrip('/') + '_aux'}/{str(img_id).zfill(5)}_room_filtered.png")
460
+
461
+ # Exclude all furnitures, excepts window, door
462
+ icon_skip_types = [0, *list(range(3, 11))]
463
+ filtered_icon_polygons = remove_polygons_by_type(icon_polygons, skip_types=icon_skip_types)
464
+ # visualize_room_polygons(mask[1], filtered_icon_polygons, list(ICON_NAMES.values()),
465
+ # bg_polygons=room_polygons, save_path=f"{save_dir.rstrip('/') + '_aux'}/{str(img_id).zfill(5)}_icon_filtered.png")
466
+
467
+ #### COMBINED
468
+ combined_polygons = merge_rooms_and_icons(filtered_room_polygons, filtered_icon_polygons)
469
+
470
+ filtered_mask1 = mask[0].copy()
471
+ filtered_mask1[np.isin(mask[0], room_skip_types)] = 0
472
+
473
+ filtered_mask2 = mask[1].copy()
474
+ filtered_mask2[np.isin(mask[1], icon_skip_types)] = 0
475
+ filtered_mask2[filtered_mask2 != 0] += 11
476
+
477
+ filtered_mask = np.where(filtered_mask2 != 0, filtered_mask2, filtered_mask1)
478
+
479
+ new_filtered_mask = filtered_mask.copy()
480
+ for src_type, dest_type in class_mapping_dict.items():
481
+ if dest_type is None:
482
+ continue
483
+ new_filtered_mask[filtered_mask == src_type] = dest_type + 1
484
+ # filtered_mask = new_filtered_mask
485
+
486
+ binary_mask = np.zeros_like(filtered_mask)
487
+ binary_mask = np.where((mask[0] + mask[1]) != 0, 1, 0).astype(np.uint8)
488
+ filled_mask = fill_holes_in_mask(binary_mask)
489
+ cv2.imwrite(
490
+ f"{save_dir.rstrip('/') + '_aux'}/{str(img_id).zfill(5) + '_mask.png'}", filled_mask.astype(np.uint8) * 255
491
+ )
492
+ # visualize_room_polygons(combined_mask, combined_polygons, list(ROOM_NAMES.values()) + list(ICON_NAMES.values()), save_path=f"{save_dir}/{str(img_id).zfill(5)}_combined.png")
493
+
494
+ save_image(
495
+ f"{args.data_root}/{image_set['folder']}/F1_scaled.png",
496
+ f"{save_dir}/{str(img_id).zfill(5) + '.png'}",
497
+ mask=filled_mask,
498
+ )
499
+ if vis_fp:
500
+ save_image(
501
+ f"{args.data_root}/{image_set['folder']}/F1_scaled.png",
502
+ f"{save_dir.rstrip('/') + '_aux'}/{str(img_id).zfill(5) + '_org.png'}",
503
+ mask=None,
504
+ )
505
+
506
+ output_polygon_list = []
507
+ combined_polygon_list = []
508
+ for poly_ind, (polygon, poly_type) in enumerate(combined_polygons):
509
+ poly_shapely = Polygon(polygon)
510
+ area = poly_shapely.area
511
+
512
+ org_poly_type = poly_type
513
+ poly_type = class_mapping_dict[poly_type]
514
+ if poly_type is None:
515
+ continue
516
+
517
+ if poly_type not in door_window_index and area < 100:
518
+ continue
519
+ if poly_type in door_window_index and area < 1:
520
+ continue
521
+
522
+ rectangle_shapely = poly_shapely.envelope
523
+ polygon = np.array(polygon)
524
+
525
+ ### here we convert door/window annotation into a single line
526
+ if poly_type in door_window_index and wd2line:
527
+ if polygon.shape[0] > 4:
528
+ if len(polygon) == 5 and (polygon[0] == polygon[-1]).all():
529
+ polygon = polygon[:-1] # drop last point since it is same as first
530
+ else:
531
+ bounding_rect = np.array(poly_shapely.minimum_rotated_rectangle.exterior.coords)
532
+ polygon = bounding_rect[:4]
533
+
534
+ assert polygon.shape[0] == 4
535
+ midp_1 = (polygon[0] + polygon[1]) / 2
536
+ midp_2 = (polygon[1] + polygon[2]) / 2
537
+ midp_3 = (polygon[2] + polygon[3]) / 2
538
+ midp_4 = (polygon[3] + polygon[0]) / 2
539
+
540
+ dist_1_3 = np.square(midp_1 - midp_3).sum()
541
+ dist_2_4 = np.square(midp_2 - midp_4).sum()
542
+ if dist_1_3 > dist_2_4:
543
+ polygon = np.row_stack([midp_1, midp_3])
544
+ else:
545
+ polygon = np.row_stack([midp_2, midp_4])
546
+
547
+ coco_seg_poly = []
548
+ poly_sorted = resort_corners(polygon)
549
+
550
+ for p in poly_sorted:
551
+ coco_seg_poly += list(p)
552
+
553
+ # Slightly wider bounding box
554
+ bb_x, bb_y = rectangle_shapely.exterior.xy
555
+ coco_bb = create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2)
556
+
557
+ coco_annotation_dict = {
558
+ "segmentation": [coco_seg_poly],
559
+ "area": area,
560
+ "iscrowd": 0,
561
+ "image_id": img_id,
562
+ "bbox": coco_bb,
563
+ "category_id": poly_type,
564
+ "id": instance_id,
565
+ }
566
+ coco_annotation_dict_list.append(coco_annotation_dict)
567
+ instance_id += 1
568
+
569
+ combined_polygon_list.append([np.array(coco_seg_poly).reshape(-1, 2), org_poly_type])
570
+ output_polygon_list.append([np.array(coco_seg_poly).reshape(-1, 2), poly_type + 1])
571
+
572
+ #### end split_file loop
573
+ save_dict["images"].append(img_dict)
574
+ save_dict["annotations"] += coco_annotation_dict_list
575
+
576
+ json_path = f"{annos_folder}/{str(img_id).zfill(5) + '.json'}"
577
+ with open(json_path, "w") as f:
578
+ json.dump(save_dict, f)
579
+
580
+ if vis_fp:
581
+ visualize_room_polygons(
582
+ filtered_mask,
583
+ combined_polygon_list,
584
+ list(ROOM_NAMES.values()) + ["window", "door"],
585
+ save_path=f"{save_dir.rstrip('/') + '_aux'}/{str(img_id).zfill(5)}_combined.png",
586
+ )
587
+ visualize_room_polygons(
588
+ new_filtered_mask,
589
+ output_polygon_list,
590
+ ["null"] + list(class_to_index_dict.keys()),
591
+ save_path=f"{save_dir.rstrip('/') + '_aux'}/{str(img_id).zfill(5)}_final.png",
592
+ )
593
+
594
+
595
+ def prepare_dict(categories_dict):
596
+ save_dict = {"images": [], "annotations": [], "categories": []}
597
+ for key, value in categories_dict.items():
598
+ type_dict = {"supercategory": "room", "id": value, "name": key}
599
+ save_dict["categories"].append(type_dict)
600
+ return save_dict
601
+
602
+
603
+ if __name__ == "__main__":
604
+ args = config()
605
+
606
+ ### prepare
607
+ outFolder = args.output
608
+ if not os.path.exists(outFolder):
609
+ os.mkdir(outFolder)
610
+
611
+ annotation_outFolder = os.path.join(outFolder, "annotations_json")
612
+ if not os.path.exists(annotation_outFolder):
613
+ os.mkdir(annotation_outFolder)
614
+
615
+ annos_train_folder = os.path.join(annotation_outFolder, "train")
616
+ annos_val_folder = os.path.join(annotation_outFolder, "val")
617
+ annos_test_folder = os.path.join(annotation_outFolder, "test")
618
+ os.makedirs(annos_train_folder, exist_ok=True)
619
+ os.makedirs(annos_val_folder, exist_ok=True)
620
+ os.makedirs(annos_test_folder, exist_ok=True)
621
+
622
+ train_img_folder = os.path.join(outFolder, "train")
623
+ val_img_folder = os.path.join(outFolder, "val")
624
+ test_img_folder = os.path.join(outFolder, "test")
625
+
626
+ for img_folder in [train_img_folder, val_img_folder, test_img_folder]:
627
+ if not os.path.exists(img_folder):
628
+ os.mkdir(img_folder)
629
+
630
+ coco_train_json_path = os.path.join(annotation_outFolder, "train.json")
631
+ coco_val_json_path = os.path.join(annotation_outFolder, "val.json")
632
+ coco_test_json_path = os.path.join(annotation_outFolder, "test.json")
633
+
634
+ ### begin processing
635
+ start_scene_id = 3500 # following index of s3d data
636
+ split_set = ["train.txt", "val.txt", "test.txt"]
637
+ save_folders = [train_img_folder, val_img_folder, test_img_folder]
638
+ coco_json_paths = [coco_train_json_path, coco_val_json_path, coco_test_json_path]
639
+ annos_folders = [annos_train_folder, annos_val_folder, annos_test_folder]
640
+
641
+ def wrapper(scene_id):
642
+ image_set = dataset[scene_id]
643
+ process_floorplan(
644
+ image_set,
645
+ scene_id,
646
+ start_scene_id,
647
+ args,
648
+ save_dir,
649
+ annos_folder,
650
+ use_org_cc5k_classs=True,
651
+ vis_fp=scene_id < 100,
652
+ wd2line=not args.disable_wd2line,
653
+ )
654
+
655
+ def worker_init(dataset_obj):
656
+ # Store dataset as global to avoid pickling issues
657
+ global dataset
658
+ dataset = dataset_obj
659
+
660
+ for split_id, split_file in enumerate(split_set):
661
+ dataset = FloorplanSVG(args.data_root, split_file, format="txt", original_size=False)
662
+ save_dir = save_folders[split_id]
663
+ json_path = coco_json_paths[split_id]
664
+ print(f"############# {split_file}")
665
+
666
+ annos_folder = annos_folders[split_id]
667
+ num_processes = 16
668
+ with Pool(num_processes, initializer=worker_init, initargs=(dataset,)) as p:
669
+ indices = range(len(dataset))
670
+ list(tqdm(p.imap(wrapper, indices), total=len(dataset)))
671
+
672
+ start_scene_id += len(dataset)
data_preprocess/cubicasa5k/floorplan_extraction.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import os
5
+ import sys
6
+ from multiprocessing import Pool
7
+ from pathlib import Path
8
+
9
+ import cv2
10
+ import numpy as np
11
+ from shapely.geometry import Polygon
12
+ from tqdm import tqdm
13
+
14
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
15
+ from common_utils import resort_corners
16
+ from create_coco_cc5k import create_coco_bounding_box
17
+
18
+ from util.plot_utils import plot_semantic_rich_floorplan_opencv
19
+
20
+
21
+ def plot_floor(output_coco_polygons, categories_dict, img_w, img_h, save_path, door_window_index=[10, 9]):
22
+ gt_sem_rich = []
23
+ for j, (poly, poly_type) in enumerate(output_coco_polygons):
24
+ corners = np.array(poly).reshape(-1, 2).astype(np.int32)
25
+ # corners_flip_y = corners.copy()
26
+ # corners_flip_y[:,1] = 255 - corners_flip_y[:,1]
27
+ # corners = corners_flip_y
28
+ gt_sem_rich.append([corners, poly_type])
29
+ # plot_semantic_rich_floorplan_nicely(gt_sem_rich, save_path, prec=None, rec=None,
30
+ # plot_text=True, is_bw=False,
31
+ # door_window_index=door_window_index,
32
+ # img_w=img_w,
33
+ # img_h=img_h,
34
+ # semantics_label_mapping=get_dataset_class_labels(categories_dict),
35
+ # )
36
+ plot_semantic_rich_floorplan_opencv(
37
+ gt_sem_rich,
38
+ save_path,
39
+ img_w=img_w,
40
+ img_h=img_h,
41
+ door_window_index=door_window_index,
42
+ semantics_label_mapping=get_dataset_class_labels(categories_dict),
43
+ is_bw=False,
44
+ )
45
+
46
+
47
+ def prepare_dict(categories_dict):
48
+ save_dict = {"images": [], "annotations": [], "categories": categories_dict}
49
+ return save_dict
50
+
51
+
52
+ def extract_polygons_from_mask(binary_mask, output_mask_path):
53
+ """
54
+ Extract polygons from a binary mask where regions with value 1 are polygons
55
+ and background regions have value 0.
56
+
57
+ Args:
58
+ binary_mask (numpy.ndarray): Binary mask with shape (H, W), where 1 represents
59
+ the polygon regions and 0 represents the background.
60
+
61
+ Returns:
62
+ list: A list of polygons, where each polygon is represented as a list of (x, y) coordinates.
63
+ """
64
+ # Ensure the mask is binary (0 and 1)
65
+ binary_mask = (binary_mask > 0).astype(np.uint8)
66
+
67
+ # Find contours in the binary mask
68
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
69
+
70
+ # Extract polygons from contours
71
+ polygons = []
72
+ for contour in contours:
73
+ # Approximate the contour to reduce the number of points
74
+ epsilon = 0.001 * cv2.arcLength(contour, True) # Adjust epsilon for more/less detail
75
+ approx_polygon = cv2.approxPolyDP(contour, epsilon, True)
76
+ polygons.append(approx_polygon.squeeze().tolist()) # Convert to list of (x, y) points
77
+
78
+ # Convert binary_mask to a 3-channel image to draw colored polylines
79
+ binary_mask_colored = cv2.cvtColor(binary_mask * 255, cv2.COLOR_GRAY2BGR)
80
+
81
+ # Plot polygons on the binary mask with green color
82
+ for polygon in polygons:
83
+ points = np.array(polygon, dtype=np.int32)
84
+ cv2.polylines(binary_mask_colored, [points], isClosed=True, color=(0, 0, 255), thickness=10)
85
+
86
+ cv2.imwrite(output_mask_path, binary_mask_colored)
87
+
88
+ return polygons
89
+
90
+
91
+ def read_polygons_from_json(json_file):
92
+ with open(json_file, "r") as f:
93
+ data = json.load(f)
94
+ category_dict = data["categories"]
95
+ polygons = [data["annotations"][i]["segmentation"][0] for i in range(len(data["annotations"]))]
96
+ poly_types = [data["annotations"][i]["category_id"] for i in range(len(data["annotations"]))]
97
+ source_misc = [data["annotations"][i] for i in range(len(data["annotations"]))]
98
+ source_polygons = [(polygons[i], poly_types[i]) for i in range(len(polygons))]
99
+
100
+ return source_polygons, source_misc, category_dict
101
+
102
+
103
+ def get_dataset_class_labels(category_dict):
104
+ return {category_dict[i]["id"]: category_dict[i]["name"] for i in range(len(category_dict))}
105
+
106
+
107
+ def check_all_window_door_inside(polygons, door_window_index):
108
+ flag = all([poly_type in door_window_index for _, poly_type in polygons])
109
+ return flag
110
+
111
+
112
+ def extract_region_and_annotation(
113
+ source_image,
114
+ source_annot_path,
115
+ region_polygons,
116
+ start_image_id,
117
+ output_image_dir="output",
118
+ output_annot_dir="annotations",
119
+ output_aux_dir="output_aux",
120
+ vis_aux=True,
121
+ ):
122
+ """
123
+ Extract regions of the floorplan from the source image based on polygons
124
+ and generate annotations.
125
+
126
+ Args:
127
+ source_image (numpy.ndarray): The source image (H, W, 3).
128
+ polygons (list): List of polygons, where each polygon is a list of (x, y) coordinates.
129
+ output_dir (str): Directory to save the extracted regions and annotations.
130
+
131
+ Returns:
132
+ list: A list of annotations for each extracted region.
133
+ """
134
+ door_window_index = [10, 9]
135
+ source_polygons, source_misc, categories_dict = read_polygons_from_json(source_annot_path)
136
+ source_img_id = os.path.basename(source_annot_path).split(".")[0].zfill(5)
137
+ if vis_aux:
138
+ gt_sem_rich_path = os.path.join(output_aux_dir, "{}_org_floor.png".format(source_img_id))
139
+ plot_floor(
140
+ source_polygons,
141
+ categories_dict,
142
+ source_image.shape[1],
143
+ source_image.shape[0],
144
+ gt_sem_rich_path,
145
+ door_window_index=door_window_index,
146
+ )
147
+ margin = 10
148
+ img_id = start_image_id
149
+
150
+ # each region polygon corresponds to an image
151
+ for i, polygon in enumerate(region_polygons):
152
+ instance_id = 0
153
+ output_coco_polygons = []
154
+ # Create a mask for the current polygon
155
+ mask = np.zeros(source_image.shape[:2], dtype=np.uint8)
156
+ points = np.array(polygon, dtype=np.int32)
157
+ cv2.fillPoly(mask, [points], 255)
158
+
159
+ # Crop the ROI to the bounding box of the polygon
160
+ x, y, w, h = cv2.boundingRect(points)
161
+ # Expand the bounding box by the margin
162
+ x_expanded = max(x - margin, 0)
163
+ y_expanded = max(y - margin, 0)
164
+ w_expanded = min(x + w + margin, source_image.shape[1]) - x_expanded
165
+ h_expanded = min(y + h + margin, source_image.shape[0]) - y_expanded
166
+
167
+ x, y, w, h = x_expanded, y_expanded, w_expanded, h_expanded
168
+ cropped_roi = source_image[y : y + h, x : x + w]
169
+
170
+ save_dict = prepare_dict(categories_dict)
171
+
172
+ # Create an annotation for the extracted region
173
+ img_dict = {}
174
+ img_dict["file_name"] = f"{str(img_id).zfill(5)}_{source_img_id}.png"
175
+ img_dict["id"] = img_id
176
+ img_dict["width"] = w
177
+ img_dict["height"] = h
178
+
179
+ # Save the cropped ROI
180
+ roi_filename = f"{output_image_dir}/{str(img_id).zfill(5)}_{source_img_id}.png"
181
+ cv2.imwrite(roi_filename, cropped_roi)
182
+
183
+ bounding_box = np.array([x, y, x + w, y + h])
184
+
185
+ # Convert source polygons to NumPy arrays for vectorized operations
186
+ source_polygons_np = [np.array(src_poly[0]).reshape(-1, 2) for src_poly in source_polygons]
187
+ assert len(source_polygons_np) == len(source_polygons)
188
+
189
+ coco_annotation_dict_list = []
190
+ # Iterate through the polygons and filter those inside the bounding box
191
+ for j, tmp in enumerate(source_polygons_np):
192
+ # Compute the bounding box of the current polygon
193
+ poly_bbox = np.hstack([np.min(tmp, axis=0), np.max(tmp, axis=0)])
194
+
195
+ # Check if the polygon is outside the bounding box
196
+ if np.any(poly_bbox[:2] < bounding_box[:2]) or np.any(poly_bbox[2:] > bounding_box[2:]):
197
+ continue
198
+
199
+ # Scale the polygon coordinates relative to the top-left corner of the bounding box
200
+ scaled_polygon = tmp - bounding_box[:2]
201
+
202
+ coco_seg_poly = []
203
+ poly_sorted = resort_corners(scaled_polygon)
204
+ # image = draw_polygon_on_image(image, poly_shapely, "test_poly.jpg")
205
+
206
+ for p in poly_sorted:
207
+ coco_seg_poly += list(p)
208
+
209
+ if len(scaled_polygon) == 2:
210
+ area = source_misc[j]["area"]
211
+ coco_bb = source_misc[j]["bbox"]
212
+ # shift the bounding box
213
+ coco_bb[0] -= bounding_box[0]
214
+ coco_bb[1] -= bounding_box[1]
215
+ else:
216
+ poly_shapely = Polygon(scaled_polygon)
217
+ area = poly_shapely.area
218
+ rectangle_shapely = poly_shapely.envelope
219
+
220
+ # Slightly wider bounding box
221
+ bb_x, bb_y = rectangle_shapely.exterior.xy
222
+ coco_bb = create_coco_bounding_box(bb_x, bb_y, w, h, bound_pad=2)
223
+
224
+ coco_annotation_dict = {
225
+ "segmentation": [coco_seg_poly],
226
+ "area": area,
227
+ "iscrowd": 0,
228
+ "image_id": img_id,
229
+ "bbox": coco_bb,
230
+ "category_id": source_polygons[j][1],
231
+ "id": instance_id,
232
+ }
233
+ coco_annotation_dict_list.append(coco_annotation_dict)
234
+ output_coco_polygons.append([coco_seg_poly, source_polygons[j][1]])
235
+
236
+ # Remove after obtaining the polygon
237
+ # source_polygons.pop(j)
238
+ # source_misc.pop(j)
239
+ instance_id += 1
240
+
241
+ # skip if just windows and doors are inside
242
+ if check_all_window_door_inside(output_coco_polygons, door_window_index):
243
+ instance_id -= len(coco_annotation_dict_list)
244
+ continue
245
+
246
+ save_dict["images"].append(img_dict)
247
+ save_dict["annotations"] += coco_annotation_dict_list
248
+
249
+ if vis_aux:
250
+ gt_sem_rich_path = os.path.join(
251
+ output_aux_dir, "{}_{}_floor.png".format(str(img_id).zfill(5), source_img_id)
252
+ )
253
+ plot_floor(
254
+ output_coco_polygons, categories_dict, w, h, gt_sem_rich_path, door_window_index=door_window_index
255
+ )
256
+
257
+ # Save annotations to a JSON file
258
+ json_path = f"{output_annot_dir}/{str(img_id).zfill(5)}_{source_img_id}.json"
259
+ with open(json_path, "w") as f:
260
+ json.dump(save_dict, f)
261
+
262
+ img_id += 1
263
+
264
+ start_image_id = img_id
265
+
266
+ return start_image_id
267
+
268
+
269
+ def config():
270
+ a = argparse.ArgumentParser(description="Generate coco format data for Structured3D")
271
+ a.add_argument(
272
+ "--data_root", default="Structured3D_panorama", type=str, help="path to raw Structured3D_panorama folder"
273
+ )
274
+ a.add_argument("--output", default="coco_cubicasa5k", type=str, help="path to output folder")
275
+
276
+ args = a.parse_args()
277
+ return args
278
+
279
+
280
+ # Example usage
281
+ if __name__ == "__main__":
282
+ args = config()
283
+
284
+ ### prepare
285
+ outFolder = args.output
286
+ if not os.path.exists(outFolder):
287
+ os.mkdir(outFolder)
288
+
289
+ annotation_outFolder = os.path.join(outFolder, "annotations_json")
290
+ if not os.path.exists(annotation_outFolder):
291
+ os.mkdir(annotation_outFolder)
292
+
293
+ annos_train_folder = os.path.join(annotation_outFolder, "train")
294
+ annos_val_folder = os.path.join(annotation_outFolder, "val")
295
+ annos_test_folder = os.path.join(annotation_outFolder, "test")
296
+ os.makedirs(annos_train_folder, exist_ok=True)
297
+ os.makedirs(annos_val_folder, exist_ok=True)
298
+ os.makedirs(annos_test_folder, exist_ok=True)
299
+
300
+ train_img_folder = os.path.join(outFolder, "train")
301
+ val_img_folder = os.path.join(outFolder, "val")
302
+ test_img_folder = os.path.join(outFolder, "test")
303
+
304
+ for img_folder in [train_img_folder, val_img_folder, test_img_folder]:
305
+ if not os.path.exists(img_folder):
306
+ os.mkdir(img_folder)
307
+
308
+ ### begin processing
309
+ start_image_id = 3500
310
+ save_folders = [train_img_folder, val_img_folder, test_img_folder]
311
+ annos_folders = [annos_train_folder, annos_val_folder, annos_test_folder]
312
+ splits = ["train", "val", "test"]
313
+
314
+ def wrapper(index):
315
+ image_path, annot_path, mask_path = packed_input_files[index]
316
+ cur_image_id = int(os.path.basename(image_path).split(".")[0])
317
+ binary_mask = cv2.imread(mask_path)[:, :, -1]
318
+ source_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
319
+ # Extract polygons
320
+ region_polygons = extract_polygons_from_mask(
321
+ binary_mask, output_mask_path=f"{save_aux_path}/{str(cur_image_id).zfill(5)}_polylines.png"
322
+ )
323
+
324
+ return extract_region_and_annotation(
325
+ source_image,
326
+ annot_path,
327
+ region_polygons,
328
+ start_image_id + index * 10,
329
+ save_path,
330
+ save_anno_path,
331
+ save_aux_path,
332
+ vis_aux=True,
333
+ )
334
+
335
+ def worker_init(input_files_object):
336
+ # Store dataset as global to avoid pickling issues
337
+ global packed_input_files
338
+ packed_input_files = input_files_object
339
+
340
+ for i, split in enumerate(splits):
341
+ image_files = sorted(glob.glob(f"{args.data_root}/{split}/*.png"))
342
+ image_id_list = [os.path.basename(image_path).split(".")[0] for image_path in image_files]
343
+ anno_files = [f"{args.data_root}/annotations_json/{split}/{id_}.json" for id_ in image_id_list]
344
+ mask_files = [f"{args.data_root}/{split}_aux/{id_}_mask.png" for id_ in image_id_list]
345
+ save_path = save_folders[i]
346
+ save_anno_path = annos_folders[i]
347
+ save_aux_path = save_path.rstrip("/") + "_aux"
348
+ os.makedirs(save_aux_path, exist_ok=True)
349
+
350
+ # for j, (image_path, anno_path, mask_path) in enumerate(zip(image_files, anno_files, mask_files)):
351
+ # cur_image_id = int(os.path.basename(image_path).split('.')[0])
352
+ # binary_mask = cv2.imread(mask_path)[:,:,-1]
353
+ # source_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
354
+
355
+ # # Extract polygons
356
+ # polygons = extract_polygons_from_mask(binary_mask, output_mask_path=f'{save_aux_path}/{str(cur_image_id).zfill(5)}_polylines.png')
357
+ # # # skip if only one polygon (floorplan)
358
+ # # if len(polygons) == 1:
359
+ # # print(f"Skipping {image_path} with only one polygon")
360
+ # # with open(anno_path, 'r') as f:
361
+ # # data = json.load(f)
362
+ # # # update image id
363
+ # # data['images'][0]['id'] = start_image_id
364
+ # # data['images'][0]["file_name"] = f'{str(start_image_id).zfill(5)}_{str(cur_image_id).zfill(5)}.png'
365
+ # # for anno in data['annotations']:
366
+ # # anno['image_id'] = start_image_id
367
+
368
+ # # with open(f"{save_anno_path}/{str(start_image_id).zfill(5)}_{str(cur_image_id).zfill(5)}.json", 'w') as f:
369
+ # # json.dump(data, f, indent=2)
370
+ # # shutil.copy(image_path, f"{save_path}/{str(start_image_id).zfill(5)}_{str(cur_image_id).zfill(5)}.png")
371
+
372
+ # # gt_sem_rich_path = os.path.join(save_aux_path, '{}_{}_floor.png'.format(str(start_image_id).zfill(5), str(cur_image_id).zfill(5)))
373
+ # # output_coco_polygons = [(x['segmentation'][0], x['category_id']) for x in data['annotations']]
374
+ # # plot_floor(output_coco_polygons, data['categories'], data['images'][0]['width'], data['images'][0]['height'], gt_sem_rich_path, door_window_index=[10, 9])
375
+
376
+ # # start_image_id += 1
377
+ # # continue
378
+
379
+ # # # Print the extracted polygons
380
+ # # print("Extracted polygons:")
381
+ # # for i, polygon in enumerate(polygons):
382
+ # # print(f"Polygon {i + 1}: {polygon}")
383
+
384
+ # start_image_id = extract_region_and_annotation(source_image,
385
+ # anno_path,
386
+ # polygons,
387
+ # start_image_id,
388
+ # output_image_dir=save_path,
389
+ # output_annot_dir=save_anno_path,
390
+ # output_aux_dir=save_aux_path,
391
+ # vis_aux=True)
392
+
393
+ packed_input_files = list(zip(image_files, anno_files, mask_files))
394
+ # for j in range(5):
395
+ # wrapper(j)
396
+ # exit(0)
397
+
398
+ num_processes = 16
399
+ with Pool(num_processes, initializer=worker_init, initargs=(packed_input_files,)) as p:
400
+ indices = [j for j in range(len(packed_input_files))]
401
+ list(tqdm(p.imap(wrapper, indices), total=len(indices)))
402
+
403
+ start_image_id += len(packed_input_files) * 10
data_preprocess/cubicasa5k/house.py ADDED
@@ -0,0 +1,1131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from xml.dom import minidom
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from skimage.draw import polygon
7
+ from svg_utils import (
8
+ PolygonWall,
9
+ calc_distance,
10
+ get_direction,
11
+ get_gaussian2D,
12
+ get_icon,
13
+ get_icon_number,
14
+ get_points,
15
+ get_room_number,
16
+ )
17
+
18
+ all_rooms = {
19
+ "Background": 0, # Not in data. The default outside label
20
+ "Alcove": 1,
21
+ "Attic": 2,
22
+ "Ballroom": 3,
23
+ "Bar": 4,
24
+ "Basement": 5,
25
+ "Bath": 6,
26
+ "Bedroom": 7,
27
+ "Below150cm": 8,
28
+ "CarPort": 9,
29
+ "Church": 10,
30
+ "Closet": 11,
31
+ "ConferenceRoom": 12,
32
+ "Conservatory": 13,
33
+ "Counter": 14,
34
+ "Den": 15,
35
+ "Dining": 16,
36
+ "DraughtLobby": 17,
37
+ "DressingRoom": 18,
38
+ "EatingArea": 19,
39
+ "Elevated": 20,
40
+ "Elevator": 21,
41
+ "Entry": 22,
42
+ "ExerciseRoom": 23,
43
+ "Garage": 24,
44
+ "Garbage": 25,
45
+ "Hall": 26,
46
+ "HallWay": 27,
47
+ "HotTub": 28,
48
+ "Kitchen": 29,
49
+ "Library": 30,
50
+ "LivingRoom": 31,
51
+ "Loft": 32,
52
+ "Lounge": 33,
53
+ "MediaRoom": 34,
54
+ "MeetingRoom": 35,
55
+ "Museum": 36,
56
+ "Nook": 37,
57
+ "Office": 38,
58
+ "OpenToBelow": 39,
59
+ "Outdoor": 40,
60
+ "Pantry": 41,
61
+ "Reception": 42,
62
+ "RecreationRoom": 43,
63
+ "RetailSpace": 44,
64
+ "Room": 45,
65
+ "Sanctuary": 46,
66
+ "Sauna": 47,
67
+ "ServiceRoom": 48,
68
+ "ServingArea": 49,
69
+ "Skylights": 50,
70
+ "Stable": 51,
71
+ "Stage": 52,
72
+ "StairWell": 53,
73
+ "Storage": 54,
74
+ "SunRoom": 55,
75
+ "SwimmingPool": 56,
76
+ "TechnicalRoom": 57,
77
+ "Theatre": 58,
78
+ "Undefined": 59,
79
+ "UserDefined": 60,
80
+ "Utility": 61,
81
+ "Wall": 62,
82
+ "Railing": 63,
83
+ "Stairs": 64,
84
+ }
85
+
86
+ rooms_selected = {
87
+ "Alcove": 11,
88
+ "Attic": 11,
89
+ "Ballroom": 11,
90
+ "Bar": 11,
91
+ "Basement": 11,
92
+ "Bath": 6,
93
+ "Bedroom": 5,
94
+ "CarPort": 10,
95
+ "Church": 11,
96
+ "Closet": 9,
97
+ "ConferenceRoom": 11,
98
+ "Conservatory": 11,
99
+ "Counter": 11,
100
+ "Den": 11,
101
+ "Dining": 4,
102
+ "DraughtLobby": 7,
103
+ "DressingRoom": 9,
104
+ "EatingArea": 4,
105
+ "Elevated": 11,
106
+ "Elevator": 11,
107
+ "Entry": 7,
108
+ "ExerciseRoom": 11,
109
+ "Garage": 10,
110
+ "Garbage": 11,
111
+ "Hall": 11,
112
+ "HallWay": 7,
113
+ "HotTub": 11,
114
+ "Kitchen": 3,
115
+ "Library": 11,
116
+ "LivingRoom": 4,
117
+ "Loft": 11,
118
+ "Lounge": 4,
119
+ "MediaRoom": 11,
120
+ "MeetingRoom": 11,
121
+ "Museum": 11,
122
+ "Nook": 11,
123
+ "Office": 11,
124
+ "OpenToBelow": 11,
125
+ "Outdoor": 1,
126
+ "Pantry": 11,
127
+ "Reception": 11,
128
+ "RecreationRoom": 11,
129
+ "RetailSpace": 11,
130
+ "Room": 11,
131
+ "Sanctuary": 11,
132
+ "Sauna": 6,
133
+ "ServiceRoom": 11,
134
+ "ServingArea": 11,
135
+ "Skylights": 11,
136
+ "Stable": 11,
137
+ "Stage": 11,
138
+ "StairWell": 11,
139
+ "Storage": 9,
140
+ "SunRoom": 11,
141
+ "SwimmingPool": 11,
142
+ "TechnicalRoom": 11,
143
+ "Theatre": 11,
144
+ "Undefined": 11,
145
+ "UserDefined": 11,
146
+ "Utility": 11,
147
+ "Background": 0, # Not in data. The default outside label
148
+ "Wall": 2,
149
+ "Railing": 8,
150
+ }
151
+
152
+ room_name_map = {
153
+ "Alcove": "Room",
154
+ "Attic": "Room",
155
+ "Ballroom": "Room",
156
+ "Bar": "Room",
157
+ "Basement": "Room",
158
+ "Bath": "Bath",
159
+ "Bedroom": "Bedroom",
160
+ "Below150cm": "Room",
161
+ "CarPort": "Garage",
162
+ "Church": "Room",
163
+ "Closet": "Storage",
164
+ "ConferenceRoom": "Room",
165
+ "Conservatory": "Room",
166
+ "Counter": "Room",
167
+ "Den": "Room",
168
+ "Dining": "Dining",
169
+ "DraughtLobby": "Entry",
170
+ "DressingRoom": "Storage",
171
+ "EatingArea": "Dining",
172
+ "Elevated": "Room",
173
+ "Elevator": "Room",
174
+ "Entry": "Entry",
175
+ "ExerciseRoom": "Room",
176
+ "Garage": "Garage",
177
+ "Garbage": "Room",
178
+ "Hall": "Room",
179
+ "HallWay": "Entry",
180
+ "HotTub": "Room",
181
+ "Kitchen": "Kitchen",
182
+ "Library": "Room",
183
+ "LivingRoom": "LivingRoom",
184
+ "Loft": "Room",
185
+ "Lounge": "LivingRoom",
186
+ "MediaRoom": "Room",
187
+ "MeetingRoom": "Room",
188
+ "Museum": "Room",
189
+ "Nook": "Room",
190
+ "Office": "Room",
191
+ "OpenToBelow": "Room",
192
+ "Outdoor": "Outdoor",
193
+ "Pantry": "Room",
194
+ "Reception": "Room",
195
+ "RecreationRoom": "Room",
196
+ "RetailSpace": "Room",
197
+ "Room": "Room",
198
+ "Sanctuary": "Room",
199
+ "Sauna": "Bath",
200
+ "ServiceRoom": "Room",
201
+ "ServingArea": "Room",
202
+ "Skylights": "Room",
203
+ "Stable": "Room",
204
+ "Stage": "Room",
205
+ "StairWell": "Room",
206
+ "Storage": "Storage",
207
+ "SunRoom": "Room",
208
+ "SwimmingPool": "Room",
209
+ "TechnicalRoom": "Room",
210
+ "Theatre": "Room",
211
+ "Undefined": "Room",
212
+ "UserDefined": "Room",
213
+ "Utility": "Room",
214
+ "Wall": "Wall",
215
+ "Railing": "Railing",
216
+ "Background": "Background",
217
+ } # Not in data. The default outside label
218
+
219
+ all_icons = {
220
+ "Empty": 0,
221
+ "Window": 1,
222
+ "Door": 2,
223
+ "BaseCabinet": 3,
224
+ "BaseCabinetRound": 4,
225
+ "BaseCabinetTriangle": 5,
226
+ "Bathtub": 6,
227
+ "BathtubRound": 7,
228
+ "Chimney": 8,
229
+ "Closet": 9,
230
+ "ClosetRound": 10,
231
+ "ClosetTriangle": 11,
232
+ "CoatCloset": 12,
233
+ "CoatRack": 13,
234
+ "CornerSink": 14,
235
+ "CounterTop": 15,
236
+ "DoubleSink": 16,
237
+ "DoubleSinkRight": 17,
238
+ "ElectricalAppliance": 18,
239
+ "Fireplace": 19,
240
+ "FireplaceCorner": 20,
241
+ "FireplaceRound": 21,
242
+ "GasStove": 22,
243
+ "Housing": 23,
244
+ "Jacuzzi": 24,
245
+ "PlaceForFireplace": 25,
246
+ "PlaceForFireplaceCorner": 26,
247
+ "PlaceForFireplaceRound": 27,
248
+ "RoundSink": 28,
249
+ "SaunaBenchHigh": 29,
250
+ "SaunaBenchLow": 30,
251
+ "SaunaBenchMid": 31,
252
+ "Shower": 32,
253
+ "ShowerCab": 33,
254
+ "ShowerScreen": 34,
255
+ "ShowerScreenRoundLeft": 35,
256
+ "ShowerScreenRoundRight": 36,
257
+ "SideSink": 37,
258
+ "Sink": 38,
259
+ "Toilet": 39,
260
+ "Urinal": 40,
261
+ "WallCabinet": 41,
262
+ "WaterTap": 42,
263
+ "WoodStove": 43,
264
+ "Misc": 44,
265
+ "SaunaBench": 45,
266
+ "SaunaStove": 46,
267
+ "WashingMachine": 47,
268
+ "IntegratedStove": 48,
269
+ "Dishwasher": 49,
270
+ "GeneralAppliance": 50,
271
+ "ShowerPlatform": 51,
272
+ }
273
+
274
+ icons_selected = {
275
+ "Window": 1,
276
+ "Door": 2,
277
+ "Closet": 3,
278
+ "ClosetRound": 3,
279
+ "ClosetTriangle": 3,
280
+ "CoatCloset": 3,
281
+ "CoatRack": 3,
282
+ "CounterTop": 3,
283
+ "Housing": 3,
284
+ "ElectricalAppliance": 4,
285
+ "WoodStove": 4,
286
+ "GasStove": 4,
287
+ "Toilet": 5,
288
+ "Urinal": 5,
289
+ "SideSink": 6,
290
+ "Sink": 6,
291
+ "RoundSink": 6,
292
+ "CornerSink": 6,
293
+ "DoubleSink": 6,
294
+ "DoubleSinkRight": 6,
295
+ "WaterTap": 6,
296
+ "SaunaBenchHigh": 7,
297
+ "SaunaBenchLow": 7,
298
+ "SaunaBenchMid": 7,
299
+ "SaunaBench": 7,
300
+ "Fireplace": 8,
301
+ "FireplaceCorner": 8,
302
+ "FireplaceRound": 8,
303
+ "PlaceForFireplace": 8,
304
+ "PlaceForFireplaceCorner": 8,
305
+ "PlaceForFireplaceRound": 8,
306
+ "Bathtub": 9,
307
+ "BathtubRound": 9,
308
+ "Chimney": 10,
309
+ "Misc": None,
310
+ "BaseCabinetRound": None,
311
+ "BaseCabinetTriangle": None,
312
+ "BaseCabinet": None,
313
+ "WallCabinet": None,
314
+ "Shower": None,
315
+ "ShowerCab": None,
316
+ "ShowerPlatform": None,
317
+ "ShowerScreen": None,
318
+ "ShowerScreenRoundRight": None,
319
+ "ShowerScreenRoundLeft": None,
320
+ "Jacuzzi": None,
321
+ }
322
+
323
+ icon_name_map = {
324
+ "Window": "Window",
325
+ "Door": "Door",
326
+ "Closet": "Closet",
327
+ "ClosetRound": "Closet",
328
+ "ClosetTriangle": "Closet",
329
+ "CoatCloset": "Closet",
330
+ "CoatRack": "Closet",
331
+ "CounterTop": "Closet",
332
+ "Housing": "Closet",
333
+ "ElectricalAppliance": "ElectricalAppliance",
334
+ "WoodStove": "ElectricalAppliance",
335
+ "GasStove": "ElectricalAppliance",
336
+ "SaunaStove": "ElectricalAppliance",
337
+ "Toilet": "Toilet",
338
+ "Urinal": "Toilet",
339
+ "SideSink": "Sink",
340
+ "Sink": "Sink",
341
+ "RoundSink": "Sink",
342
+ "CornerSink": "Sink",
343
+ "DoubleSink": "Sink",
344
+ "DoubleSinkRight": "Sink",
345
+ "WaterTap": "Sink",
346
+ "SaunaBenchHigh": "SaunaBench",
347
+ "SaunaBenchLow": "SaunaBench",
348
+ "SaunaBenchMid": "SaunaBench",
349
+ "SaunaBench": "SaunaBench",
350
+ "Fireplace": "Fireplace",
351
+ "FireplaceCorner": "Fireplace",
352
+ "FireplaceRound": "Fireplace",
353
+ "PlaceForFireplace": "Fireplace",
354
+ "PlaceForFireplaceCorner": "Fireplace",
355
+ "PlaceForFireplaceRound": "Fireplace",
356
+ "Bathtub": "Bathtub",
357
+ "BathtubRound": "Bathtub",
358
+ "Chimney": "Chimney",
359
+ "Misc": None,
360
+ "BaseCabinetRound": None,
361
+ "BaseCabinetTriangle": None,
362
+ "BaseCabinet": None,
363
+ "WallCabinet": None,
364
+ "Shower": "None",
365
+ "ShowerCab": "None",
366
+ "ShowerPlatform": "None",
367
+ "ShowerScreen": None,
368
+ "ShowerScreenRoundRight": None,
369
+ "ShowerScreenRoundLeft": None,
370
+ "Jacuzzi": None,
371
+ "WashingMachine": None,
372
+ "IntegratedStove": "ElectricalAppliance",
373
+ "Dishwasher": "ElectricalAppliance",
374
+ "GeneralAppliance": "ElectricalAppliance",
375
+ }
376
+
377
+
378
+ def complete_polygons(polygons, polygon_types):
379
+ new_polygons = []
380
+ new_types = []
381
+ for poly, poly_type in zip(polygons, polygon_types):
382
+ if len(poly) < 3:
383
+ print(f"Class {poly_type} has less than 3 points. Skipped!")
384
+ continue
385
+ poly_array = np.array(poly)
386
+ t = copy.copy(poly)
387
+ # append the beginning point
388
+ if len(poly_array) > 2 and (poly_array[0] != poly_array[-1]).any():
389
+ t.append(poly[0])
390
+ new_polygons.append(t)
391
+ new_types.append(poly_type)
392
+
393
+ return new_polygons, new_types
394
+
395
+
396
+ class House:
397
+ def __init__(self, path, height, width, icon_list=icons_selected, room_list=rooms_selected):
398
+ self.height = height
399
+ self.width = width
400
+ shape = height, width
401
+ svg = minidom.parse(path)
402
+ self.walls = np.empty((height, width), dtype=np.uint8)
403
+ self.walls.fill(0)
404
+ self.wall_ids = np.empty((height, width), dtype=np.uint8)
405
+ self.wall_ids.fill(0)
406
+ self.icons = np.zeros((height, width), dtype=np.uint8)
407
+ # junction_id = 0
408
+ wall_id = 1
409
+ self.wall_ends = []
410
+ self.wall_objs = []
411
+ self.icon_types = []
412
+ self.room_types = []
413
+ self.icon_corners = {"upper_left": [], "upper_right": [], "lower_left": [], "lower_right": []}
414
+ self.opening_corners = {"left": [], "right": [], "up": [], "down": []}
415
+ self.representation = {"doors": [], "icons": [], "labels": [], "walls": []}
416
+
417
+ self.icon_areas = []
418
+ self.wall_coords = []
419
+ self.icon_coords = []
420
+
421
+ for e in svg.getElementsByTagName("g"):
422
+ try:
423
+ if e.getAttribute("id") == "Wall":
424
+ wall = PolygonWall(e, wall_id, shape)
425
+ wall.rr, wall.cc = self._clip_outside(wall.rr, wall.cc)
426
+ self.wall_objs.append(wall)
427
+ self.walls[wall.rr, wall.cc] = room_list["Wall"]
428
+ self.wall_ids[wall.rr, wall.cc] = wall_id
429
+ self.wall_ends.append(wall.end_points)
430
+
431
+ Y, X = self._clip_outside(wall.Y, wall.X)
432
+ self.wall_coords.append([(x, y) for x, y in zip(X, Y)])
433
+ self.room_types.append(room_list["Wall"])
434
+
435
+ wall_id += 1
436
+
437
+ if e.getAttribute("id") == "Railing":
438
+ wall = PolygonWall(e, wall_id, shape)
439
+ wall.rr, wall.cc = self._clip_outside(wall.rr, wall.cc)
440
+ self.wall_objs.append(wall)
441
+ self.walls[wall.rr, wall.cc] = room_list["Railing"]
442
+ self.wall_ids[wall.rr, wall.cc] = wall_id
443
+ self.wall_ends.append(wall.end_points)
444
+
445
+ Y, X = self._clip_outside(wall.Y, wall.X)
446
+ self.wall_coords.append([(x, y) for x, y in zip(X, Y)])
447
+ self.room_types.append(room_list["Railing"])
448
+
449
+ wall_id += 1
450
+
451
+ except ValueError as k:
452
+ if str(k) != "small wall":
453
+ raise k
454
+ continue
455
+
456
+ if e.getAttribute("id") == "Window":
457
+ X, Y = get_points(e)
458
+ rr, cc = polygon(X, Y)
459
+ cc, rr = self._clip_outside(cc, rr)
460
+ direction = get_direction(X, Y)
461
+ locs = np.column_stack((X, Y))
462
+ if direction == "H":
463
+ left_index = np.argmin(locs[:, 0])
464
+ left1 = locs[left_index]
465
+ locs = np.delete(locs, left_index, axis=0)
466
+ left_index = np.argmin(locs[:, 0])
467
+ left2 = locs[left_index]
468
+ right = np.delete(locs, left_index, axis=0)
469
+ left = np.array([left1, left2])
470
+
471
+ point_left = left.mean(axis=0)
472
+ point_right = right.mean(axis=0)
473
+ self.opening_corners["left"].append(point_left)
474
+ self.opening_corners["right"].append(point_right)
475
+
476
+ door_rep = [[list(point_left), list(point_right)], ["door", 1, 1]]
477
+ self.representation["doors"].append(door_rep)
478
+ else:
479
+ up_index = np.argmin(locs[:, 1])
480
+ up1 = locs[up_index]
481
+ locs = np.delete(locs, up_index, axis=0)
482
+ up_index = np.argmin(locs[:, 1])
483
+ up2 = locs[up_index]
484
+ down = np.delete(locs, up_index, axis=0)
485
+ up = np.array([up1, up2])
486
+
487
+ point_up = up.mean(axis=0)
488
+ point_down = down.mean(axis=0)
489
+ self.opening_corners["up"].append(point_up)
490
+ self.opening_corners["down"].append(point_down)
491
+
492
+ door_rep = [[list(point_up), list(point_down)], ["door", 1, 1]]
493
+ self.representation["doors"].append(door_rep)
494
+
495
+ self.icons[cc, rr] = 1
496
+ self.icon_types.append(1)
497
+
498
+ Y, X = self._clip_outside(Y, X)
499
+ self.icon_coords.append([(x, y) for x, y in zip(X, Y)])
500
+
501
+ if e.getAttribute("id") == "Door":
502
+ # How to reperesent empty door space
503
+ X, Y = get_points(e)
504
+ rr, cc = polygon(X, Y)
505
+ cc, rr = self._clip_outside(cc, rr)
506
+ direction = get_direction(X, Y)
507
+ locs = np.column_stack((X, Y))
508
+ if direction == "H":
509
+ left_index = np.argmin(locs[:, 0])
510
+ left1 = locs[left_index]
511
+ locs = np.delete(locs, left_index, axis=0)
512
+ left_index = np.argmin(locs[:, 0])
513
+ left2 = locs[left_index]
514
+ right = np.delete(locs, left_index, axis=0)
515
+ left = np.array([left1, left2])
516
+
517
+ point_left = left.mean(axis=0)
518
+ point_right = right.mean(axis=0)
519
+ self.opening_corners["left"].append(left.mean(axis=0))
520
+ self.opening_corners["right"].append(right.mean(axis=0))
521
+
522
+ door_rep = [[list(point_left), list(point_right)], ["door", 1, 1]]
523
+ self.representation["doors"].append(door_rep)
524
+ else:
525
+ up_index = np.argmin(locs[:, 1])
526
+ up1 = locs[up_index]
527
+ locs = np.delete(locs, up_index, axis=0)
528
+ up_index = np.argmin(locs[:, 1])
529
+ up2 = locs[up_index]
530
+ down = np.delete(locs, up_index, axis=0)
531
+ up = np.array([up1, up2])
532
+
533
+ point_up = up.mean(axis=0)
534
+ point_down = down.mean(axis=0)
535
+ self.opening_corners["up"].append(up.mean(axis=0))
536
+ self.opening_corners["down"].append(down.mean(axis=0))
537
+
538
+ door_rep = [[list(point_up), list(point_down)], ["door", 1, 1]]
539
+ self.representation["doors"].append(door_rep)
540
+
541
+ self.icons[cc, rr] = 2
542
+ self.icon_types.append(2)
543
+
544
+ Y, X = self._clip_outside(Y, X)
545
+ self.icon_coords.append([(x, y) for x, y in zip(X, Y)])
546
+
547
+ if "FixedFurniture " in e.getAttribute("class"):
548
+ num = get_icon_number(e, icon_list)
549
+ if num is not None:
550
+ rr, cc, X, Y = get_icon(e)
551
+ # only four corner icons
552
+ if len(X) == 4:
553
+ locs = np.column_stack((X, Y))
554
+ up_left_index = locs.sum(axis=1).argmin()
555
+ self.icon_corners["upper_left"].append(locs[up_left_index])
556
+ up_left = list(locs[up_left_index])
557
+ locs = np.delete(locs, up_left_index, axis=0)
558
+ down_right_index = locs.sum(axis=1).argmax()
559
+ self.icon_corners["lower_right"].append(locs[down_right_index])
560
+ down_right = list(locs[down_right_index])
561
+ locs = np.delete(locs, down_right_index, axis=0)
562
+ up_right_index = locs[:, 1].argmin()
563
+ self.icon_corners["upper_right"].append(locs[up_right_index])
564
+ locs = np.delete(locs, up_right_index, axis=0)
565
+ self.icon_corners["lower_left"].append(locs[0])
566
+
567
+ icon_name = e.getAttribute("class").replace("FixedFurniture ", "").split(" ")[0]
568
+ icon_name = icon_name_map[icon_name]
569
+
570
+ icon_rep = [[up_left, down_right], [icon_name, 1, 1]]
571
+ self.representation["icons"].append(icon_rep)
572
+
573
+ rr, cc = self._clip_outside(rr, cc)
574
+ self.icon_areas.append(len(rr))
575
+ self.icons[rr, cc] = num
576
+ self.icon_types.append(num)
577
+
578
+ Y, X = self._clip_outside(Y, X)
579
+ self.icon_coords.append([(x, y) for x, y in zip(X, Y)])
580
+
581
+ if "Space " in e.getAttribute("class"):
582
+ num = get_room_number(e, room_list)
583
+ # rr, cc = get_polygon(e)
584
+ X, Y = get_points(e)
585
+ rr, cc = polygon(Y, X)
586
+ if len(rr) != 0:
587
+ rr, cc = self._clip_outside(rr, cc)
588
+ if len(rr) != 0 and len(cc) != 0:
589
+ self.walls[rr, cc] = num
590
+ self.room_types.append(num)
591
+
592
+ Y, X = self._clip_outside(Y, X)
593
+ self.wall_coords.append([(x, y) for x, y in zip(X, Y)])
594
+
595
+ rr_mean = int(round(np.mean(rr)))
596
+ cc_mean = int(round(np.mean(cc)))
597
+ center_box = [[rr_mean - 10, cc_mean - 10], [rr_mean + 10, cc_mean + 10]]
598
+ room_name = e.getAttribute("class").replace("Space ", "").split(" ")[0]
599
+ room_name = room_name_map[room_name]
600
+ self.representation["labels"].append([center_box, [room_name, 1, 1]])
601
+
602
+ # if "Stairs" in e.getAttribute("class"):
603
+ # for c in e.childNodes:
604
+ # if c.getAttribute("class") in ["Flight", "Winding"]:
605
+ # num = room_list["Stairs"]
606
+ # rr, cc = get_polygon(c)
607
+ # if len(rr) != 0:
608
+ # rr, cc = self._clip_outside(rr, cc)
609
+ # if len(rr) != 0 and len(cc) != 0:
610
+ # self.walls[rr, cc] = num
611
+ # self.room_types.append(num)
612
+
613
+ # rr_mean = int(round(np.mean(rr)))
614
+ # cc_mean = int(round(np.mean(cc)))
615
+ # center_box = [[rr_mean-10, cc_mean-10], [rr_mean+10, cc_mean+10]]
616
+ # room_name = "Stairs"
617
+ # # room_name = room_name_map[room_name]
618
+ # self.representation['labels'].append([center_box, [room_name, 1, 1]])
619
+
620
+ self.avg_wall_width = self.get_avg_wall_width()
621
+
622
+ self.new_walls = self.connect_walls(self.wall_objs)
623
+
624
+ for w in self.new_walls:
625
+ w.change_end_points()
626
+
627
+ for w in self.pillar_walls:
628
+ self.new_walls.append(w)
629
+
630
+ self.points = self.lines_to_points(self.width, self.height, self.new_walls, self.avg_wall_width)
631
+ self.points = self.merge_joints(self.points, self.avg_wall_width)
632
+
633
+ # walls to representation
634
+ for w in self.new_walls:
635
+ end_points = w.end_points.round().astype("int").tolist()
636
+ if w.name == "Wall":
637
+ self.representation["walls"].append([end_points, ["wall", 1, 1]])
638
+ else:
639
+ self.representation["walls"].append([end_points, ["wall", 2, 1]])
640
+
641
+ # append begining point at last pos
642
+ print("Complete room coords")
643
+ self.wall_coords, self.room_types = complete_polygons(self.wall_coords, self.room_types)
644
+ print("Complete icon coords")
645
+ self.icon_coords, self.icon_types = complete_polygons(self.icon_coords, self.icon_types)
646
+
647
+ def get_coords_and_labels(self):
648
+ assert len(self.wall_coords) == len(self.room_types)
649
+ assert len(self.icon_coords) == len(self.icon_types)
650
+ return self.wall_coords, self.room_types, self.icon_coords, self.icon_types
651
+
652
+ def get_tensor(self):
653
+ heatmaps = self.get_heatmaps()
654
+ wall_t = np.expand_dims(self.walls, axis=0)
655
+ icon_t = np.expand_dims(self.icons, axis=0)
656
+ tensor = np.concatenate((heatmaps, wall_t, icon_t), axis=0)
657
+
658
+ return tensor
659
+
660
+ def get_segmentation_tensor(self):
661
+ wall_t = np.expand_dims(self.walls, axis=0)
662
+ icon_t = np.expand_dims(self.icons, axis=0)
663
+ tensor = np.concatenate((wall_t, icon_t), axis=0)
664
+
665
+ return tensor
666
+
667
+ def get_heatmap_dict(self):
668
+ # init dict
669
+ heatmaps = {}
670
+ for i in range(21):
671
+ heatmaps[i] = []
672
+
673
+ for p in self.points:
674
+ cord, _, p_type = p
675
+ x = int(np.round(cord[0]))
676
+ y = int(np.round(cord[1]))
677
+ channel = self.get_number(p_type)
678
+ if y < self.height and x < self.width:
679
+ heatmaps[channel - 1] = heatmaps[channel - 1] + [(x, y)]
680
+
681
+ channel = 13
682
+ for i in self.opening_corners["left"]:
683
+ y = int(i[1])
684
+ x = int(i[0])
685
+ if y < self.height and x < self.width:
686
+ heatmaps[channel] = heatmaps[channel] + [(x, y)]
687
+ channel += 1
688
+ for i in self.opening_corners["right"]:
689
+ y = int(i[1])
690
+ x = int(i[0])
691
+ if y < self.height and x < self.width:
692
+ heatmaps[channel] = heatmaps[channel] + [(x, y)]
693
+ channel += 1
694
+ for i in self.opening_corners["up"]:
695
+ y = int(i[1])
696
+ x = int(i[0])
697
+ if y < self.height and x < self.width:
698
+ heatmaps[channel] = heatmaps[channel] + [(x, y)]
699
+ channel += 1
700
+ for i in self.opening_corners["down"]:
701
+ y = int(i[1])
702
+ x = int(i[0])
703
+ if y < self.height and x < self.width:
704
+ heatmaps[channel] = heatmaps[channel] + [(x, y)]
705
+ channel += 1
706
+
707
+ for i in self.icon_corners["upper_left"]:
708
+ y = int(i[1])
709
+ x = int(i[0])
710
+ if y < self.height and x < self.width:
711
+ heatmaps[channel] = heatmaps[channel] + [(x, y)]
712
+ channel += 1
713
+ for i in self.icon_corners["upper_right"]:
714
+ y = int(i[1])
715
+ x = int(i[0])
716
+ if y < self.height and x < self.width:
717
+ heatmaps[channel] = heatmaps[channel] + [(x, y)]
718
+ channel += 1
719
+ for i in self.icon_corners["lower_left"]:
720
+ y = int(i[1])
721
+ x = int(i[0])
722
+ if y < self.height and x < self.width:
723
+ heatmaps[channel] = heatmaps[channel] + [(x, y)]
724
+ channel += 1
725
+ for i in self.icon_corners["lower_right"]:
726
+ y = int(i[1])
727
+ x = int(i[0])
728
+ if y < self.height and x < self.width:
729
+ heatmaps[channel] = heatmaps[channel] + [(x, y)]
730
+
731
+ return heatmaps
732
+
733
+ def get_heatmaps(self):
734
+ heatmaps = np.zeros((21, self.height, self.width))
735
+ for p in self.points:
736
+ cord, _, p_type = p
737
+ x = int(np.round(cord[0]))
738
+ y = int(np.round(cord[1]))
739
+ channel = self.get_number(p_type)
740
+ if y < self.height and x < self.width:
741
+ heatmaps[channel - 1, y, x] = 1
742
+
743
+ channel = 13
744
+ for i in self.opening_corners["left"]:
745
+ y = int(i[1])
746
+ x = int(i[0])
747
+ if y < self.height and x < self.width:
748
+ heatmaps[channel, y, x] = 1
749
+ channel += 1
750
+ for i in self.opening_corners["right"]:
751
+ y = int(i[1])
752
+ x = int(i[0])
753
+ if y < self.height and x < self.width:
754
+ heatmaps[channel, y, x] = 1
755
+ channel += 1
756
+ for i in self.opening_corners["up"]:
757
+ y = int(i[1])
758
+ x = int(i[0])
759
+ if y < self.height and x < self.width:
760
+ heatmaps[channel, y, x] = 1
761
+ channel += 1
762
+ for i in self.opening_corners["down"]:
763
+ y = int(i[1])
764
+ x = int(i[0])
765
+ if y < self.height and x < self.width:
766
+ heatmaps[channel, y, x] = 1
767
+ channel += 1
768
+
769
+ for i in self.icon_corners["upper_left"]:
770
+ y = int(i[1])
771
+ x = int(i[0])
772
+ if y < self.height and x < self.width:
773
+ heatmaps[channel, y, x] = 1
774
+ channel += 1
775
+ for i in self.icon_corners["upper_right"]:
776
+ y = int(i[1])
777
+ x = int(i[0])
778
+ if y < self.height and x < self.width:
779
+ heatmaps[channel, y, x] = 1
780
+ channel += 1
781
+ for i in self.icon_corners["lower_left"]:
782
+ y = int(i[1])
783
+ x = int(i[0])
784
+ if y < self.height and x < self.width:
785
+ heatmaps[channel, y, x] = 1
786
+ channel += 1
787
+ for i in self.icon_corners["lower_right"]:
788
+ y = int(i[1])
789
+ x = int(i[0])
790
+ if y < self.height and x < self.width:
791
+ heatmaps[channel, y, x] = 1
792
+
793
+ kernel = get_gaussian2D(13)
794
+ for i, h in enumerate(heatmaps):
795
+ heatmaps[i] = cv2.filter2D(h, -1, kernel)
796
+
797
+ return heatmaps
798
+
799
+ def _clip_outside(self, rr, cc):
800
+ s = np.column_stack((rr, cc))
801
+ s = s[s[:, 0] < self.height]
802
+ s = s[s[:, 1] < self.width]
803
+
804
+ return s[:, 0], s[:, 1]
805
+
806
+ def lines_to_points(self, width, height, walls, lineWidth):
807
+ lines = [h.end_points for h in walls]
808
+
809
+ points = []
810
+ usedLinePointMask = []
811
+
812
+ for lineIndex, line in enumerate(lines):
813
+ usedLinePointMask.append([False, False])
814
+
815
+ for lineIndex_1, wall_1 in enumerate(walls):
816
+ line_1 = wall_1.end_points
817
+
818
+ lineDim_1 = self.get_lineDim(line_1, 1)
819
+ if lineDim_1 <= -1:
820
+ # If wall is diagonal we skip
821
+ continue
822
+
823
+ fixedValue_1 = (line_1[0][1 - lineDim_1] + line_1[1][1 - lineDim_1]) / 2
824
+ for lineIndex_2, wall_2 in enumerate(walls):
825
+ line_2 = wall_2.end_points
826
+
827
+ if lineIndex_2 <= lineIndex_1:
828
+ continue
829
+
830
+ lineDim_2 = self.get_lineDim(line_2, 1)
831
+ if lineDim_2 + lineDim_1 != 1:
832
+ # if walls have the same direction we skip
833
+ continue
834
+
835
+ fixedValue_2 = (line_2[0][1 - lineDim_2] + line_2[1][1 - lineDim_2]) / 2
836
+ lineWidth = max(wall_1.max_width, wall_2.max_width)
837
+ nearestPair, minDistance = self.findNearestJunctionPair(line_1, line_2, lineWidth)
838
+
839
+ if minDistance <= lineWidth:
840
+ pointIndex_1 = nearestPair[0]
841
+ pointIndex_2 = nearestPair[1]
842
+ if pointIndex_1 > -1 and pointIndex_2 > -1:
843
+ point = [None, None]
844
+ point[lineDim_1] = fixedValue_2
845
+ point[lineDim_2] = fixedValue_1
846
+ side = [None, None]
847
+ side[lineDim_1] = line_1[1 - pointIndex_1][lineDim_1] - fixedValue_2
848
+ side[lineDim_2] = line_2[1 - pointIndex_2][lineDim_2] - fixedValue_1
849
+
850
+ if side[0] < 0 and side[1] < 0:
851
+ points.append([point, point, ["point", 2, 1]])
852
+ elif side[0] > 0 and side[1] < 0:
853
+ points.append([point, point, ["point", 2, 2]])
854
+ elif side[0] > 0 and side[1] > 0:
855
+ points.append([point, point, ["point", 2, 3]])
856
+ elif side[0] < 0 and side[1] > 0:
857
+ points.append([point, point, ["point", 2, 4]])
858
+
859
+ usedLinePointMask[lineIndex_1][pointIndex_1] = True
860
+ usedLinePointMask[lineIndex_2][pointIndex_2] = True
861
+ elif (pointIndex_1 > -1 and pointIndex_2 == -1) or (pointIndex_1 == -1 and pointIndex_2 > -1):
862
+ if pointIndex_1 > -1:
863
+ lineDim = lineDim_1
864
+ pointIndex = pointIndex_1
865
+ fixedValue = fixedValue_2
866
+ pointValue = line_1[pointIndex_1][1 - lineDim_1]
867
+ usedLinePointMask[lineIndex_1][pointIndex_1] = True
868
+ else:
869
+ lineDim = lineDim_2
870
+ pointIndex = pointIndex_2
871
+ fixedValue = fixedValue_1
872
+ pointValue = line_2[pointIndex_2][1 - lineDim_2]
873
+ usedLinePointMask[lineIndex_2][pointIndex_2] = True
874
+
875
+ point = [None, None]
876
+ point[lineDim] = fixedValue
877
+ point[1 - lineDim] = pointValue
878
+
879
+ if pointIndex == 0:
880
+ if lineDim == 0:
881
+ points.append([point, point, ["point", 3, 4]])
882
+ else:
883
+ points.append([point, point, ["point", 3, 1]])
884
+ else:
885
+ if lineDim == 0:
886
+ points.append([point, point, ["point", 3, 2]])
887
+ else:
888
+ points.append([point, point, ["point", 3, 3]])
889
+
890
+ elif (
891
+ line_1[0][lineDim_1] < fixedValue_2
892
+ and line_1[1][lineDim_1] > fixedValue_2
893
+ and line_2[0][lineDim_2] < fixedValue_1
894
+ and line_2[1][lineDim_2] > fixedValue_1
895
+ ):
896
+ point = [None, None]
897
+ point[lineDim_1] = fixedValue_2
898
+ point[lineDim_2] = fixedValue_1
899
+ points.append([point, point, ["point", 4, 1]])
900
+
901
+ for lineIndex, pointMask in enumerate(usedLinePointMask):
902
+ lineDim = self.get_lineDim(lines[lineIndex], 1)
903
+ for pointIndex in range(2):
904
+ if pointMask[pointIndex] is True:
905
+ continue
906
+ point = [lines[lineIndex][pointIndex][0], lines[lineIndex][pointIndex][1]]
907
+ if pointIndex == 0:
908
+ if lineDim == 0:
909
+ points.append([point, point, ["point", 1, 4]])
910
+ elif lineDim == 1:
911
+ points.append([point, point, ["point", 1, 1]])
912
+ else:
913
+ if lineDim == 0:
914
+ points.append([point, point, ["point", 1, 2]])
915
+ elif lineDim == 1:
916
+ points.append([point, point, ["point", 1, 3]])
917
+
918
+ return points
919
+
920
+ def _pointId2index(self, g, t):
921
+ g_ = g - 1
922
+ t_ = t - 1
923
+ k = g_ * 4 + t_
924
+ return k
925
+
926
+ def _index2pointId(self, k):
927
+ g = k // 4 + 1
928
+ t = k % 4 + 1
929
+ return [g, t]
930
+
931
+ def _are_close(self, p1, p2, width):
932
+ return calc_distance(p1, p2) < width
933
+
934
+ def merge_joints(self, points, wall_width):
935
+ lookuptable = {}
936
+ lookuptable[0] = {0: 0, 1: 7, 2: None, 3: 6, 4: 9, 5: 11, 6: 6, 7: 7, 8: 8, 9: 9, 10: 12, 11: 11, 12: 12}
937
+ lookuptable[1] = {0: 7, 1: 1, 2: 4, 3: None, 4: 4, 5: 10, 6: 8, 7: 7, 8: 8, 9: 9, 10: 10, 11: 12, 12: 12}
938
+ lookuptable[2] = {0: None, 1: 4, 2: 2, 3: 5, 4: 4, 5: 5, 6: 11, 7: 9, 8: 12, 9: 9, 10: 10, 11: 11, 12: 12}
939
+ lookuptable[3] = {0: 6, 1: None, 2: 5, 3: 3, 4: 10, 5: 5, 6: 6, 7: 8, 8: 8, 9: 12, 10: 10, 11: 11, 12: 12}
940
+ lookuptable[4] = {0: 9, 1: 4, 2: 4, 3: 10, 4: 4, 5: 10, 6: 12, 7: 9, 8: 12, 9: 9, 10: 10, 11: 12, 12: 12}
941
+ lookuptable[5] = {0: 11, 1: 10, 2: 5, 3: 5, 4: 10, 5: 5, 6: 11, 7: 12, 8: 12, 9: 12, 10: 10, 11: 11, 12: 12}
942
+ lookuptable[6] = {0: 6, 1: 8, 2: 11, 3: 6, 4: 12, 5: 11, 6: 6, 7: 8, 8: 8, 9: 12, 10: 12, 11: 11, 12: 12}
943
+ lookuptable[7] = {0: 7, 1: 7, 2: 9, 3: 8, 4: 9, 5: 12, 6: 8, 7: 7, 8: 8, 9: 9, 10: 12, 11: 12, 12: 12}
944
+ lookuptable[8] = {0: 8, 1: 8, 2: 12, 3: 8, 4: 12, 5: 12, 6: 8, 7: 8, 8: 8, 9: 12, 10: 12, 11: 12, 12: 12}
945
+ lookuptable[9] = {0: 9, 1: 9, 2: 9, 3: 12, 4: 9, 5: 12, 6: 12, 7: 9, 8: 12, 9: 9, 10: 12, 11: 12, 12: 12}
946
+ lookuptable[10] = {
947
+ 0: 12,
948
+ 1: 10,
949
+ 2: 10,
950
+ 3: 10,
951
+ 4: 10,
952
+ 5: 10,
953
+ 6: 12,
954
+ 7: 12,
955
+ 8: 12,
956
+ 9: 12,
957
+ 10: 10,
958
+ 11: 12,
959
+ 12: 12,
960
+ }
961
+ lookuptable[11] = {
962
+ 0: 11,
963
+ 1: 12,
964
+ 2: 11,
965
+ 3: 11,
966
+ 4: 12,
967
+ 5: 11,
968
+ 6: 11,
969
+ 7: 12,
970
+ 8: 12,
971
+ 9: 12,
972
+ 10: 12,
973
+ 11: 11,
974
+ 12: 12,
975
+ }
976
+ lookuptable[12] = {
977
+ 0: 12,
978
+ 1: 12,
979
+ 2: 12,
980
+ 3: 12,
981
+ 4: 12,
982
+ 5: 12,
983
+ 6: 12,
984
+ 7: 12,
985
+ 8: 12,
986
+ 9: 12,
987
+ 10: 12,
988
+ 11: 12,
989
+ 12: 12,
990
+ }
991
+
992
+ newPoints = []
993
+ merged = [False] * len(points)
994
+ for i, point1 in enumerate(points):
995
+ if merged[i] is False:
996
+ pool = [point1]
997
+ for j, point2 in enumerate(points):
998
+ if j != i and merged[j] is False and self._are_close(point1[0], point2[0], wall_width):
999
+ merged[j] = True
1000
+ pool.append(point2)
1001
+
1002
+ if len(pool) == 1:
1003
+ newPoints.append(point1)
1004
+ merged[i] = True
1005
+ else:
1006
+ p_ = pool[0]
1007
+ for point_id in range(1, len(pool)):
1008
+ merge_to_p = pool[point_id]
1009
+
1010
+ k_ = self._pointId2index(p_[2][1], p_[2][2])
1011
+ k_merge_to_p = self._pointId2index(merge_to_p[2][1], merge_to_p[2][2])
1012
+
1013
+ knew = lookuptable[k_][k_merge_to_p]
1014
+ if knew is None:
1015
+ continue
1016
+
1017
+ typenew = self._index2pointId(knew)
1018
+ p_ = [p_[0], p_[1], ["point", typenew[0], typenew[1]]]
1019
+
1020
+ newPoints.append(p_)
1021
+
1022
+ return newPoints
1023
+
1024
+ def get_avg_wall_width(self):
1025
+ res = 0
1026
+ for i, w in enumerate(self.wall_objs):
1027
+ res += w.max_width
1028
+ res = res / float(i)
1029
+
1030
+ return res
1031
+
1032
+ def connect_walls(self, walls):
1033
+ new_walls = []
1034
+ num_walls = len(walls)
1035
+ remaining_walls = list(range(1, num_walls + 1))
1036
+
1037
+ # getting pillars
1038
+ remaining_pillar_ids = []
1039
+ for p_id in range(1, num_walls + 1):
1040
+ p_wall = self.find_wall_by_id(p_id, walls)
1041
+ if p_wall.wall_is_pillar(self.avg_wall_width):
1042
+ for wall_id in range(1, num_walls + 1):
1043
+ wall = self.find_wall_by_id(wall_id, walls)
1044
+ if p_wall.merge_possible(wall):
1045
+ break
1046
+ else:
1047
+ remaining_walls.pop(remaining_walls.index(p_wall.id))
1048
+ remaining_pillar_ids.append(p_wall.id)
1049
+
1050
+ while len(remaining_walls) > 0:
1051
+ new_wall_id = remaining_walls.pop(0)
1052
+ new_wall = self.find_wall_by_id(new_wall_id, walls)
1053
+
1054
+ found = True
1055
+ while found:
1056
+ found = False
1057
+ for merge_wall_id in remaining_walls:
1058
+ merged = self.find_wall_by_id(merge_wall_id, walls)
1059
+ temp_wall = new_wall.merge_walls(merged)
1060
+
1061
+ if temp_wall is not None:
1062
+ remaining_walls.pop(remaining_walls.index(merged.id))
1063
+ new_wall = temp_wall
1064
+ found = True
1065
+
1066
+ new_walls.append(new_wall)
1067
+
1068
+ # connect pillars to walls
1069
+ new_wall_id = num_walls + 1
1070
+ self.pillar_walls = []
1071
+ for id in remaining_pillar_ids:
1072
+ w = self.find_wall_by_id(id, walls)
1073
+ pws = w.split_pillar_wall(new_wall_id, self.avg_wall_width)
1074
+ new_wall_id += 4
1075
+ for pw in pws:
1076
+ self.pillar_walls.append(pw)
1077
+
1078
+ return new_walls
1079
+
1080
+ def get_number(self, x):
1081
+ return (x[1] - 1) * 4 + x[2]
1082
+
1083
+ def get_lineDim(self, line, lineWidth):
1084
+ lineWidth = lineWidth or 1
1085
+ if abs(line[0][0] - line[1][0]) > abs(line[0][1] - line[1][1]) and abs(line[0][1] - line[1][1]) <= lineWidth:
1086
+ return 0
1087
+ elif abs(line[0][1] - line[1][1]) > abs(line[0][0] - line[1][0]) and abs(line[0][0] - line[1][0]) <= lineWidth:
1088
+ return 1
1089
+ else:
1090
+ return -1
1091
+
1092
+ def findNearestJunctionPair(self, line_1, line_2, gap):
1093
+
1094
+ minDistance = None
1095
+ for index_1 in range(0, 2):
1096
+ for index_2 in range(0, 2):
1097
+ distance = calc_distance(line_1[index_1], line_2[index_2])
1098
+ if minDistance is None or distance < minDistance:
1099
+ nearestPair = [index_1, index_2]
1100
+ minDistance = distance
1101
+
1102
+ if minDistance > gap:
1103
+ lineDim_1 = self.get_lineDim(line_1, 1)
1104
+ lineDim_2 = self.get_lineDim(line_2, 1)
1105
+
1106
+ if lineDim_1 + lineDim_2 == 1:
1107
+ fixedValue_1 = (line_1[0][1 - lineDim_1] + line_1[1][1 - lineDim_1]) / 2
1108
+ fixedValue_2 = (line_2[0][1 - lineDim_2] + line_2[1][1 - lineDim_2]) / 2
1109
+
1110
+ if line_2[0][lineDim_2] < fixedValue_1 and line_2[1][lineDim_2] > fixedValue_1:
1111
+ for index in range(2):
1112
+ distance = abs(line_1[index][lineDim_1] - fixedValue_2)
1113
+ if distance < minDistance:
1114
+ nearestPair = [index, -1]
1115
+ minDistance = distance
1116
+
1117
+ if line_1[0][lineDim_1] < fixedValue_2 and line_1[1][lineDim_1] > fixedValue_2:
1118
+ for index in range(2):
1119
+ distance = abs(line_2[index][lineDim_2] - fixedValue_1)
1120
+ if distance < minDistance:
1121
+ nearestPair = [-1, index]
1122
+ minDistance = distance
1123
+
1124
+ return nearestPair, minDistance
1125
+
1126
+ def find_wall_by_id(self, id, walls):
1127
+ for wall in walls:
1128
+ if wall.id == id:
1129
+ return wall
1130
+
1131
+ return None
data_preprocess/cubicasa5k/loaders.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import lmdb
2
+ import pickle
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from house import House
8
+ from numpy import genfromtxt
9
+ from torch.utils.data import Dataset
10
+
11
+ ROOM_NAMES = {
12
+ 0: "Background",
13
+ 1: "Outdoor",
14
+ 2: "Wall",
15
+ 3: "Kitchen",
16
+ 4: "Living Room",
17
+ 5: "Bed Room",
18
+ 6: "Bath",
19
+ 7: "Entry",
20
+ 8: "Railing",
21
+ 9: "Storage",
22
+ 10: "Garage",
23
+ 11: "Undefined",
24
+ }
25
+
26
+ ICON_NAMES = {
27
+ 0: "No Icon",
28
+ 1: "Window",
29
+ 2: "Door",
30
+ 3: "Closet",
31
+ 4: "Electrical Applience",
32
+ 5: "Toilet",
33
+ 6: "Sink",
34
+ 7: "Sauna Bench",
35
+ 8: "Fire Place",
36
+ 9: "Bathtub",
37
+ 10: "Chimney",
38
+ }
39
+
40
+
41
+ class FloorplanSVG(Dataset):
42
+ def __init__(
43
+ self,
44
+ data_folder,
45
+ data_file,
46
+ is_transform=True,
47
+ augmentations=None,
48
+ img_norm=True,
49
+ format="txt",
50
+ original_size=False,
51
+ lmdb_folder="cubi_lmdb/",
52
+ ):
53
+ self.img_norm = img_norm
54
+ self.is_transform = is_transform
55
+ self.augmentations = augmentations
56
+ self.get_data = None
57
+ self.original_size = original_size
58
+ self.image_file_name = "/F1_scaled.png"
59
+ self.org_image_file_name = "/F1_original.png"
60
+ self.svg_file_name = "/model.svg"
61
+
62
+ if format == "txt":
63
+ self.get_data = self.get_txt
64
+ # if format == 'lmdb':
65
+ # self.lmdb = lmdb.open(data_folder+lmdb_folder, readonly=True,
66
+ # max_readers=8, lock=False,
67
+ # readahead=True, meminit=False)
68
+ # self.get_data = self.get_lmdb
69
+ # self.is_transform = False
70
+
71
+ self.data_folder = data_folder
72
+ # Load txt file to list
73
+ self.folders = genfromtxt(data_folder + data_file, dtype="str")
74
+
75
+ def __len__(self):
76
+ """__len__"""
77
+ return len(self.folders)
78
+
79
+ def __getitem__(self, index):
80
+ sample = self.get_data(index)
81
+
82
+ if self.augmentations is not None:
83
+ sample = self.augmentations(sample)
84
+
85
+ if self.is_transform:
86
+ sample = self.transform(sample)
87
+
88
+ return sample
89
+
90
+ def get_txt(self, index):
91
+ fplan = cv2.imread(self.data_folder + self.folders[index] + self.image_file_name)
92
+ fplan = cv2.cvtColor(fplan, cv2.COLOR_BGR2RGB) # correct color channels
93
+ height, width, nchannel = fplan.shape
94
+ fplan = np.moveaxis(fplan, -1, 0)
95
+
96
+ # Getting labels for segmentation and heatmaps
97
+ house = House(self.data_folder + self.folders[index] + self.svg_file_name, height, width)
98
+ # Combining them to one numpy tensor
99
+ label = torch.tensor(house.get_segmentation_tensor().astype(np.float32))
100
+ heatmaps = house.get_heatmap_dict()
101
+ room_polygons, room_types, icon_polygons, icon_types = house.get_coords_and_labels()
102
+ coef_width = 1
103
+ if self.original_size:
104
+ fplan = cv2.imread(self.data_folder + self.folders[index] + self.org_image_file_name)
105
+ fplan = cv2.cvtColor(fplan, cv2.COLOR_BGR2RGB) # correct color channels
106
+ height_org, width_org, nchannel = fplan.shape
107
+ fplan = np.moveaxis(fplan, -1, 0)
108
+ label = label.unsqueeze(0)
109
+ label = torch.nn.functional.interpolate(label, size=(height_org, width_org), mode="nearest")
110
+ label = label.squeeze(0)
111
+
112
+ coef_height = float(height_org) / float(height)
113
+ coef_width = float(width_org) / float(width)
114
+ for key, value in heatmaps.items():
115
+ heatmaps[key] = [(int(round(x * coef_width)), int(round(y * coef_height))) for x, y in value]
116
+
117
+ new_room_polygons = []
118
+ for poly in room_polygons:
119
+ new_room_polygons.append([(int(round(x * coef_width)), int(round(y * coef_height))) for x, y in poly])
120
+ room_polygons = new_room_polygons
121
+
122
+ new_icon_polygons = []
123
+ for poly in icon_polygons:
124
+ new_icon_polygons.append([(int(round(x * coef_width)), int(round(y * coef_height))) for x, y in poly])
125
+ icon_polygons = new_icon_polygons
126
+
127
+ img = torch.tensor(fplan.astype(np.float32))
128
+
129
+ sample = {
130
+ "image": img,
131
+ "label": label,
132
+ "folder": self.folders[index],
133
+ "heatmaps": heatmaps,
134
+ "scale": coef_width,
135
+ "room_polygon": room_polygons,
136
+ "room_type": room_types,
137
+ "icon_polygon": icon_polygons,
138
+ "icon_type": icon_types,
139
+ }
140
+
141
+ return sample
142
+
143
+ def get_lmdb(self, index):
144
+ key = self.folders[index].encode()
145
+ with self.lmdb.begin(write=False) as f:
146
+ data = f.get(key)
147
+
148
+ sample = pickle.loads(data)
149
+ return sample
150
+
151
+ def transform(self, sample):
152
+ fplan = sample["image"]
153
+ # Normalization values to range -1 and 1
154
+ fplan = 2 * (fplan / 255.0) - 1
155
+
156
+ sample["image"] = fplan
157
+
158
+ return sample
data_preprocess/cubicasa5k/plotting.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.path as mplp
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from matplotlib import cm, colors
5
+ from shapely.geometry import Point, Polygon
6
+ from skimage import draw
7
+
8
+
9
+ def discrete_cmap_furukawa():
10
+ """create a colormap with N (N<15) discrete colors and register it"""
11
+ # define individual colors as hex values
12
+ cpool = [
13
+ "#696969",
14
+ "#b3de69",
15
+ "#ffffb3",
16
+ "#8dd3c7",
17
+ "#fdb462",
18
+ "#fccde5",
19
+ "#80b1d3",
20
+ "#d9d9d9",
21
+ "#fb8072",
22
+ "#577a4d",
23
+ "white",
24
+ "#000000",
25
+ "#e31a1c",
26
+ ]
27
+ cmap3 = colors.ListedColormap(cpool, "rooms_furukawa")
28
+ cm.register_cmap(cmap=cmap3)
29
+
30
+ cpool = [
31
+ "#ede676",
32
+ "#8dd3c7",
33
+ "#b15928",
34
+ "#fdb462",
35
+ "#ffff99",
36
+ "#fccde5",
37
+ "#80b1d3",
38
+ "#d9d9d9",
39
+ "#fb8072",
40
+ "#696969",
41
+ "#577a4d",
42
+ "#e31a1c",
43
+ "#42ef59",
44
+ "#8c595a",
45
+ "#3131e5",
46
+ "#48e0e6",
47
+ "white",
48
+ ]
49
+ cmap3 = colors.ListedColormap(cpool, "icons_furukawa")
50
+ cm.register_cmap(cmap=cmap3)
51
+
52
+
53
+ def drawJunction(h, point, point_type, width, height):
54
+ lineLength = 15
55
+ lineWidth = 10
56
+ x, y = point
57
+ # plt.text(x,y,str(index),fontsize=25,color='r')
58
+ if point_type == -1:
59
+ h.scatter(x, y, color="#6488ea")
60
+ ###########################
61
+ # o
62
+ # | #6488ea soft blue
63
+ # | drawcode = [1,1]
64
+ #
65
+ ###########################
66
+ if point_type == 0:
67
+ h.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="#6488ea")
68
+ # plt.scatter(x, y-10, c='k')
69
+ ###########################
70
+ #
71
+ # ---o #6241c7 bluey purple
72
+ # drawcode = [1,2]
73
+ #
74
+ ###########################
75
+ elif point_type == 1:
76
+ h.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="#6241c7")
77
+ # plt.scatter(x+10, y, c='k')
78
+ ###########################
79
+ # |
80
+ # | drawcode = [1,3]
81
+ # o #056eee cerulean blue
82
+ #
83
+ ###########################
84
+ elif point_type == 2:
85
+ h.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="#056eee")
86
+ # plt.scatter(x, y+10, c='k')
87
+ ###########################
88
+ #
89
+ # drawcode = [1,4]
90
+ #
91
+ # o--- #004577 prussian blue
92
+ #
93
+ ###########################
94
+ elif point_type == 3:
95
+ h.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="#004577")
96
+ # plt.scatter(x-10, y, c='k')
97
+ ###########################
98
+ #
99
+ # |--- drawcode = [2,3]
100
+ # |
101
+ #
102
+ ###########################
103
+ elif point_type == 6:
104
+ h.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="#04d8b2")
105
+ h.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="#04d8b2")
106
+ ###########################
107
+ #
108
+ # ---|
109
+ # | drawcode = [2,4]
110
+ #
111
+ ###########################
112
+ elif point_type == 7:
113
+ h.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="#cdfd02")
114
+ h.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="#cdfd02")
115
+ ###########################
116
+ # |
117
+ # ---| drawcode = [2,1]
118
+ #
119
+ #
120
+ ###########################
121
+ elif point_type == 4:
122
+ h.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="#ff81c0")
123
+ h.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="#ff81c0")
124
+ ###########################
125
+ #
126
+ # |
127
+ # | drawcode = [2,2]
128
+ # --
129
+ #
130
+ ###########################
131
+ elif point_type == 5:
132
+ h.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="#f97306")
133
+ h.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="#f97306")
134
+ ###########################
135
+ #
136
+ # |
137
+ # |--- drawcode = [3,4]
138
+ # |
139
+ #
140
+ ###########################
141
+ elif point_type == 11:
142
+ h.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="b")
143
+ h.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="b")
144
+ h.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="b")
145
+ ###########################
146
+ #
147
+ # ---
148
+ # | drawcode = [3,1]
149
+ # |
150
+ #
151
+ ###########################
152
+ elif point_type == 8:
153
+ h.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="y")
154
+ h.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="y")
155
+ h.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="y")
156
+ ###########################
157
+ #
158
+ # |
159
+ # ---| drawcode = [3,2]
160
+ # |
161
+ #
162
+ ###########################
163
+ elif point_type == 9:
164
+ h.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="r")
165
+ h.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="r")
166
+ h.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="r")
167
+ ###########################
168
+ #
169
+ # |
170
+ # | drawcode = [3,3]
171
+ # ---
172
+ #
173
+ ###########################
174
+ elif point_type == 10:
175
+ h.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="m")
176
+ h.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="m")
177
+ h.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="m")
178
+ ###########################
179
+ #
180
+ # |
181
+ # --- drawcode = [4,1]
182
+ # |
183
+ #
184
+ ###########################
185
+ elif point_type == 12:
186
+ h.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="k")
187
+ h.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="k")
188
+ h.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="k")
189
+ h.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="k")
190
+
191
+ lineLength = 10
192
+ lineWidth = 5
193
+
194
+ ###########################
195
+ # o--- opening left
196
+ ###########################
197
+ if point_type == 13:
198
+ h.plot([x], [y], "o", markersize=30, color="red")
199
+ h.plot([x], [y], "o", markersize=25, color="white")
200
+ h.text(x, y, "OL", fontsize=30, color="magenta")
201
+ ###########################
202
+ # ---o opening right
203
+ ###########################
204
+ elif point_type == 14:
205
+ h.plot([x], [y], "o", markersize=30, color="red")
206
+ h.plot([x], [y], "o", markersize=25, color="white")
207
+ h.text(x, y, "OR", fontsize=30, color="magenta")
208
+ ###########################
209
+ # o opening up
210
+ # |
211
+ # |
212
+ ###########################
213
+ elif point_type == 15:
214
+ h.plot([x], [y], "o", markersize=30, color="red")
215
+ h.plot([x], [y], "o", markersize=25, color="white")
216
+ h.text(x, y, "OU", fontsize=30, color="mediumblue")
217
+ ###########################
218
+ # | opening down
219
+ # |
220
+ # o
221
+ ###########################
222
+ elif point_type == 16:
223
+ h.plot([x], [y], "o", markersize=30, color="red")
224
+ h.plot([x], [y], "o", markersize=25, color="white")
225
+ h.text(x, y, "OD", fontsize=30, color="mediumblue")
226
+
227
+ ###########################
228
+ #
229
+ # |--- drawcode = [2,3]
230
+ # |
231
+ #
232
+ ###########################
233
+ elif point_type == 17:
234
+ h.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="indianred")
235
+ h.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="indianred")
236
+ ###########################
237
+ #
238
+ # ---|
239
+ # | drawcode = [2,4]
240
+ #
241
+ ###########################
242
+ elif point_type == 18:
243
+ h.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="darkred")
244
+ h.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="darkred")
245
+ ###########################
246
+ #
247
+ # |
248
+ # | drawcode = [2,2]
249
+ # --
250
+ #
251
+ ###########################
252
+ elif point_type == 19:
253
+ h.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="salmon")
254
+ h.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="salmon")
255
+ ###########################
256
+ # |
257
+ # ---| drawcode = [2,1]
258
+ #
259
+ #
260
+ ###########################
261
+ elif point_type == 20:
262
+ h.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="orangered")
263
+ h.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="orangered")
264
+
265
+
266
+ def draw_junction_from_dict(point_dict, width, height, size=1, fontsize=30):
267
+ index = 0
268
+ markersize_large = 20 * size
269
+ markersize_small = 15 * size
270
+ for point_type, locations in point_dict.items():
271
+ for loc in locations:
272
+ x, y = loc
273
+ lineLength = 20 * size
274
+ lineWidth = 20 * size
275
+ # plt.text(x,y,str(index),fontsize=25,color='r')
276
+ ###########################
277
+ # o
278
+ # | #6488ea soft blue
279
+ # | drawcode = [1,1]
280
+ #
281
+ ###########################
282
+ if point_type == 0:
283
+ plt.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="#6488ea")
284
+ # plt.scatter(x, y-10, c='k')
285
+ ###########################
286
+ #
287
+ # ---o #6241c7 bluey purple
288
+ # drawcode = [1,2]
289
+ #
290
+ ###########################
291
+ elif point_type == 1:
292
+ plt.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="#6241c7")
293
+ # plt.scatter(x+10, y, c='k')
294
+ ###########################
295
+ # |
296
+ # | drawcode = [1,3]
297
+ # o #056eee cerulean blue
298
+ #
299
+ ###########################
300
+ elif point_type == 2:
301
+ plt.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="#056eee")
302
+ # plt.scatter(x, y+10, c='k')
303
+ ###########################
304
+ #
305
+ # drawcode = [1,4]
306
+ #
307
+ # o--- #004577 prussian blue
308
+ #
309
+ ###########################
310
+ elif point_type == 3:
311
+ plt.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="#004577")
312
+ # plt.scatter(x-10, y, c='k')
313
+ ###########################
314
+ #
315
+ # |--- drawcode = [2,3]
316
+ # |
317
+ #
318
+ ###########################
319
+ elif point_type == 6:
320
+ plt.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="#04d8b2")
321
+ plt.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="#04d8b2")
322
+ ###########################
323
+ #
324
+ # ---|
325
+ # | drawcode = [2,4]
326
+ #
327
+ ###########################
328
+ elif point_type == 7:
329
+ plt.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="#cdfd02")
330
+ plt.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="#cdfd02")
331
+ ###########################
332
+ # |
333
+ # ---| drawcode = [2,1]
334
+ #
335
+ #
336
+ ###########################
337
+ elif point_type == 4:
338
+ plt.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="#ff81c0")
339
+ plt.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="#ff81c0")
340
+ ###########################
341
+ #
342
+ # |
343
+ # | drawcode = [2,2]
344
+ # --
345
+ #
346
+ ###########################
347
+ elif point_type == 5:
348
+ plt.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="#f97306")
349
+ plt.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="#f97306")
350
+ ###########################
351
+ #
352
+ # |
353
+ # |--- drawcode = [3,4]
354
+ # |
355
+ #
356
+ ###########################
357
+ elif point_type == 11:
358
+ plt.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="b")
359
+ plt.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="b")
360
+ plt.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="b")
361
+ ###########################
362
+ #
363
+ # ---
364
+ # | drawcode = [3,1]
365
+ # |
366
+ #
367
+ ###########################
368
+ elif point_type == 8:
369
+ plt.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="y")
370
+ plt.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="y")
371
+ plt.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="y")
372
+ ###########################
373
+ #
374
+ # |
375
+ # ---| drawcode = [3,2]
376
+ # |
377
+ #
378
+ ###########################
379
+ elif point_type == 9:
380
+ plt.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="r")
381
+ plt.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="r")
382
+ plt.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="r")
383
+ ###########################
384
+ #
385
+ # |
386
+ # | drawcode = [3,3]
387
+ # ---
388
+ #
389
+ ###########################
390
+ elif point_type == 10:
391
+ plt.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="m")
392
+ plt.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="m")
393
+ plt.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="m")
394
+ ###########################
395
+ #
396
+ # |
397
+ # --- drawcode = [4,1]
398
+ # |
399
+ #
400
+ ###########################
401
+ elif point_type == 12:
402
+ plt.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="k")
403
+ plt.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="k")
404
+ plt.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="k")
405
+ plt.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="k")
406
+
407
+ lineLength = 15 * size
408
+ lineWidth = 15 * size
409
+
410
+ ###########################
411
+ # o--- opening left
412
+ ###########################
413
+ if point_type == 13:
414
+ plt.plot([x], [y], "o", markersize=markersize_large, color="red")
415
+ plt.plot([x], [y], "o", markersize=markersize_small, color="white")
416
+ plt.text(x, y, "OL", fontsize=fontsize, color="magenta")
417
+ ###########################
418
+ # ---o opening right
419
+ ###########################
420
+ elif point_type == 14:
421
+ plt.plot([x], [y], "o", markersize=markersize_large, color="red")
422
+ plt.plot([x], [y], "o", markersize=markersize_small, color="white")
423
+ plt.text(x, y, "OR", fontsize=fontsize, color="magenta")
424
+ ###########################
425
+ # o opening up
426
+ # |
427
+ # |
428
+ ###########################
429
+ elif point_type == 15:
430
+ plt.plot([x], [y], "o", markersize=markersize_large, color="red")
431
+ plt.plot([x], [y], "o", markersize=markersize_small, color="white")
432
+ plt.text(x, y, "OU", fontsize=fontsize, color="mediumblue")
433
+ ###########################
434
+ # | opening down
435
+ # |
436
+ # o
437
+ ###########################
438
+ elif point_type == 16:
439
+ plt.plot([x], [y], "o", markersize=markersize_large, color="red")
440
+ plt.plot([x], [y], "o", markersize=markersize_small, color="white")
441
+ plt.text(x, y, "OD", fontsize=fontsize, color="mediumblue")
442
+
443
+ ###########################
444
+ #
445
+ # |--- drawcode = [2,3]
446
+ # |
447
+ #
448
+ ###########################
449
+ elif point_type == 17:
450
+ plt.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="indianred")
451
+ plt.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="indianred")
452
+ ###########################
453
+ #
454
+ # ---|
455
+ # | drawcode = [2,4]
456
+ #
457
+ ###########################
458
+ elif point_type == 18:
459
+ plt.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="darkred")
460
+ plt.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="darkred")
461
+ ###########################
462
+ #
463
+ # |
464
+ # | drawcode = [2,2]
465
+ # --
466
+ #
467
+ ###########################
468
+ elif point_type == 19:
469
+ plt.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="salmon")
470
+ plt.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="salmon")
471
+ ###########################
472
+ # |
473
+ # ---| drawcode = [2,1]
474
+ #
475
+ #
476
+ ###########################
477
+ elif point_type == 20:
478
+ plt.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="orangered")
479
+ plt.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="orangered")
480
+
481
+ index += 1
482
+
483
+
484
+ def plot_pre_rec_4(instances, classes):
485
+ walls = ["Wall", "Railing"]
486
+ openings = ["Window", "Door"]
487
+ rooms = [
488
+ "Outdoor",
489
+ "Kitchen",
490
+ "Living Room",
491
+ "Bed Room",
492
+ "Entry",
493
+ "Dining",
494
+ "Storage",
495
+ "Garage",
496
+ "Undefined Room",
497
+ "Sauna",
498
+ "Fire Place",
499
+ "Bathtub",
500
+ "Chimney",
501
+ ]
502
+ icons = [
503
+ "Bath",
504
+ "Closet",
505
+ "Electrical Appliance",
506
+ "Toilet",
507
+ "Shower",
508
+ "Sink",
509
+ "Sauna",
510
+ "Fire Place",
511
+ "Bathtub",
512
+ "Chimney",
513
+ ]
514
+
515
+ def make_sub_plot(classes_to_plot):
516
+ plt.ylim([0.0, 1.0])
517
+ plt.xlim([0.0, 1.0])
518
+ plt.xlabel("Recall")
519
+ plt.ylabel("Precision")
520
+ indx = [classes.index(i) for i in classes_to_plot]
521
+ ins = instances[:, indx].sum(axis=1)
522
+
523
+ correct = ins[:, 0]
524
+ false_positive = ins[:, 2]
525
+ false_negatives = ins[:, 1]
526
+ precision = correct / (correct + false_positive)
527
+ recall = correct / (correct + false_negatives)
528
+
529
+ plt.step(recall[::-1], precision, color="b", alpha=0.2, where="post")
530
+ plt.fill_between(recall[::-1], precision, step="post", alpha=0.2, color="b")
531
+
532
+ plt.subplot(2, 2, 1)
533
+ plt.title("Walls")
534
+ make_sub_plot(walls)
535
+ plt.subplot(2, 2, 2)
536
+ plt.title("Openings")
537
+ make_sub_plot(openings)
538
+ plt.subplot(2, 2, 3)
539
+ plt.title("Rooms")
540
+ make_sub_plot(rooms)
541
+ plt.subplot(2, 2, 4)
542
+ plt.title("Icons")
543
+ make_sub_plot(icons)
544
+
545
+
546
+ def discrete_cmap():
547
+ """create a colormap with N (N<15) discrete colors and register it"""
548
+ # define individual colors as hex values
549
+ cpool = [
550
+ "#DCDCDC",
551
+ "#b3de69",
552
+ "#000000",
553
+ "#8dd3c7",
554
+ "#fdb462",
555
+ "#fccde5",
556
+ "#80b1d3",
557
+ "#808080",
558
+ "#fb8072",
559
+ "#696969",
560
+ "#577a4d",
561
+ "#ffffb3",
562
+ ]
563
+ cmap3 = colors.ListedColormap(cpool, "rooms")
564
+ cm.register_cmap(cmap=cmap3)
565
+
566
+ cpool = [
567
+ "#DCDCDC",
568
+ "#8dd3c7",
569
+ "#b15928",
570
+ "#fdb462",
571
+ "#ffff99",
572
+ "#fccde5",
573
+ "#80b1d3",
574
+ "#808080",
575
+ "#fb8072",
576
+ "#696969",
577
+ "#577a4d",
578
+ ]
579
+ cmap3 = colors.ListedColormap(cpool, "icons")
580
+ cm.register_cmap(cmap=cmap3)
581
+
582
+ """create a colormap with N (N<15) discrete colors and register it"""
583
+ # define individual colors as hex values
584
+ cpool = [
585
+ "#DCDCDC",
586
+ "#b3de69",
587
+ "#000000",
588
+ "#8dd3c7",
589
+ "#fdb462",
590
+ "#fccde5",
591
+ "#80b1d3",
592
+ "#808080",
593
+ "#fb8072",
594
+ "#696969",
595
+ "#577a4d",
596
+ "#ffffb3",
597
+ "d3d5d7",
598
+ ]
599
+ cmap3 = colors.ListedColormap(cpool, "rooms_furu")
600
+ cm.register_cmap(cmap=cmap3)
601
+
602
+ cpool = [
603
+ "#DCDCDC",
604
+ "#8dd3c7",
605
+ "#b15928",
606
+ "#fdb462",
607
+ "#ffff99",
608
+ "#fccde5",
609
+ "#80b1d3",
610
+ "#808080",
611
+ "#fb8072",
612
+ "#696969",
613
+ "#577a4d",
614
+ ]
615
+ cmap3 = colors.ListedColormap(cpool, "rooms_furu")
616
+ cm.register_cmap(cmap=cmap3)
617
+
618
+
619
+ def segmentation_plot(rooms_pred, icons_pred, rooms_label, icons_label):
620
+ room_classes = [
621
+ "Background",
622
+ "Outdoor",
623
+ "Wall",
624
+ "Kitchen",
625
+ "Living Room",
626
+ "Bed Room",
627
+ "Bath",
628
+ "Entry",
629
+ "Railing",
630
+ "Storage",
631
+ "Garage",
632
+ "Undefined",
633
+ ]
634
+ icon_classes = [
635
+ "No Icon",
636
+ "Window",
637
+ "Door",
638
+ "Closet",
639
+ "Electrical Applience",
640
+ "Toilet",
641
+ "Sink",
642
+ "Sauna Bench",
643
+ "Fire Place",
644
+ "Bathtub",
645
+ "Chimney",
646
+ ]
647
+ discrete_cmap() # custom colormap
648
+
649
+ fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(30, 15))
650
+ axes[0].set_title("Room Ground Truth")
651
+ axes[0].imshow(rooms_label, cmap="rooms", vmin=0, vmax=len(room_classes) - 1)
652
+
653
+ axes[1].set_title("Room Prediction")
654
+ im = axes[1].imshow(rooms_pred, cmap="rooms", vmin=0, vmax=len(room_classes) - 1)
655
+
656
+ cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
657
+ cbar = fig.colorbar(im, cax=cbar_ax, ticks=np.arange(12) + 0.5)
658
+
659
+ fig.subplots_adjust(right=0.8)
660
+ cbar.ax.set_yticklabels(room_classes)
661
+ plt.show()
662
+
663
+ fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(30, 15))
664
+ axes[0].set_title("Icon Ground Truth")
665
+ axes[0].imshow(icons_label, cmap="icons", vmin=0, vmax=len(icon_classes) - 1)
666
+
667
+ axes[1].set_title("Icon Prediction")
668
+ im = axes[1].imshow(icons_pred, cmap="icons", vmin=0, vmax=len(icon_classes) - 1)
669
+
670
+ cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
671
+ cbar = fig.colorbar(im, cax=cbar_ax, ticks=np.arange(11) + 0.5)
672
+
673
+ fig.subplots_adjust(right=0.8)
674
+ cbar.ax.set_yticklabels(icon_classes)
675
+ plt.show()
676
+
677
+
678
+ def polygons_to_image(polygons, types, room_polygons, room_types, height, width):
679
+ pol_room_seg = np.zeros((height, width))
680
+ pol_icon_seg = np.zeros((height, width))
681
+
682
+ for i, pol in enumerate(room_polygons):
683
+ mask = shp_mask(pol, np.arange(width), np.arange(height))
684
+
685
+ # jj, ii = draw.polygon(pol[:, 1], pol[:, 0])
686
+ pol_room_seg[mask] = room_types[i]["class"]
687
+
688
+ for i, pol in enumerate(polygons):
689
+ jj, ii = draw.polygon(pol[:, 1], pol[:, 0])
690
+ if types[i]["type"] == "wall":
691
+ pol_room_seg[jj, ii] = types[i]["class"]
692
+ else:
693
+ pol_icon_seg[jj, ii] = types[i]["class"]
694
+
695
+ return pol_room_seg, pol_icon_seg
696
+
697
+
698
+ def plot_room(r, name, n_classes=12):
699
+ discrete_cmap() # custom colormap
700
+ plt.figure(figsize=(40, 30))
701
+ plt.axis("off")
702
+ plt.tight_layout()
703
+ plt.imshow(r, cmap="rooms", vmin=0, vmax=n_classes - 1)
704
+ plt.savefig(name + ".png", format="png")
705
+ plt.show()
706
+
707
+
708
+ def plot_icon(i, name, n_classes=11):
709
+ discrete_cmap() # custom colormap
710
+ plt.figure(figsize=(40, 30))
711
+ plt.axis("off")
712
+ plt.tight_layout()
713
+ plt.imshow(i, cmap="icons", vmin=0, vmax=n_classes - 1)
714
+ plt.savefig(name + ".png", format="png")
715
+ plt.show()
716
+
717
+
718
+ def plot_heatmaps(h, name):
719
+ for index, i in enumerate(h):
720
+ plt.figure(figsize=(40, 30))
721
+ plt.axis("off")
722
+ plt.tight_layout()
723
+ plt.imshow(i, cmap="Reds", vmin=0, vmax=1)
724
+ plt.savefig(name + str(index) + ".png", format="png")
725
+ plt.show()
726
+
727
+
728
+ def outline_to_mask(line, x, y):
729
+ """Create mask from outline contour
730
+
731
+ Parameters
732
+ ----------
733
+ line: array-like (N, 2)
734
+ x, y: 1-D grid coordinates (input for meshgrid)
735
+
736
+ Returns
737
+ -------
738
+ mask : 2-D boolean array (True inside)
739
+
740
+ Examples
741
+ --------
742
+ >>> from shapely.geometry import Point
743
+ >>> poly = Point(0,0).buffer(1)
744
+ >>> x = np.linspace(-5,5,100)
745
+ >>> y = np.linspace(-5,5,100)
746
+ >>> mask = outline_to_mask(poly.boundary, x, y)
747
+ """
748
+ mpath = mplp.Path(line)
749
+ X, Y = np.meshgrid(x, y)
750
+ points = np.array((X.flatten(), Y.flatten())).T
751
+ mask = mpath.contains_points(points).reshape(X.shape)
752
+ return mask
753
+
754
+
755
+ def _grid_bbox(x, y):
756
+ dx = dy = 0
757
+ return x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2
758
+
759
+
760
+ def _bbox_to_rect(bbox):
761
+ l, r, b, t = bbox
762
+ return Polygon([(l, b), (r, b), (r, t), (l, t)])
763
+
764
+
765
+ def shp_mask(shp, x, y, m=None):
766
+ """
767
+ Adapted from code written by perrette
768
+ form: https://gist.github.com/perrette/a78f99b76aed54b6babf3597e0b331f8
769
+ Use recursive sub-division of space and shapely contains method to create a raster mask on a regular grid.
770
+
771
+ Parameters
772
+ ----------
773
+ shp : shapely's Polygon (or whatever with a "contains" method and intersects method)
774
+ x, y : 1-D numpy arrays defining a regular grid
775
+ m : mask to fill, optional (will be created otherwise)
776
+
777
+ Returns
778
+ -------
779
+ m : boolean 2-D array, True inside shape.
780
+
781
+ Examples
782
+ --------
783
+ >>> from shapely.geometry import Point
784
+ >>> poly = Point(0,0).buffer(1)
785
+ >>> x = np.linspace(-5,5,100)
786
+ >>> y = np.linspace(-5,5,100)
787
+ >>> mask = shp_mask(poly, x, y)
788
+ """
789
+ rect = _bbox_to_rect(_grid_bbox(x, y))
790
+
791
+ if m is None:
792
+ m = np.zeros((y.size, x.size), dtype=bool)
793
+
794
+ if not shp.intersects(rect):
795
+ m[:] = False
796
+
797
+ elif shp.contains(rect):
798
+ m[:] = True
799
+
800
+ else:
801
+ k, l = m.shape
802
+
803
+ if k == 1 and l == 1:
804
+ m[:] = shp.contains(Point(x[0], y[0]))
805
+
806
+ elif k == 1:
807
+ m[:, : l // 2] = shp_mask(shp, x[: l // 2], y, m[:, : l // 2])
808
+ m[:, l // 2 :] = shp_mask(shp, x[l // 2 :], y, m[:, l // 2 :])
809
+
810
+ elif l == 1:
811
+ m[: k // 2] = shp_mask(shp, x, y[: k // 2], m[: k // 2])
812
+ m[k // 2 :] = shp_mask(shp, x, y[k // 2 :], m[k // 2 :])
813
+
814
+ else:
815
+ m[: k // 2, : l // 2] = shp_mask(shp, x[: l // 2], y[: k // 2], m[: k // 2, : l // 2])
816
+ m[: k // 2, l // 2 :] = shp_mask(shp, x[l // 2 :], y[: k // 2], m[: k // 2, l // 2 :])
817
+ m[k // 2 :, : l // 2] = shp_mask(shp, x[: l // 2], y[k // 2 :], m[k // 2 :, : l // 2])
818
+ m[k // 2 :, l // 2 :] = shp_mask(shp, x[l // 2 :], y[k // 2 :], m[k // 2 :, l // 2 :])
819
+
820
+ return m
data_preprocess/cubicasa5k/run.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # create COCO-style dataset for CubiCasa5k
2
+ python create_coco_cc5k.py --data_root=data/cubicasa5k/ \
3
+ --output=data/coco_cubicasa5k_nowalls_v4/ \
4
+ --disable_wd2line
5
+
6
+ # Split example has more than 1 floorplan into separate samples
7
+ python floorplan_extraction.py \
8
+ --data_root data/coco_cubicasa5k_nowalls_v4/ \
9
+ --output data/coco_cubicasa5k_nowalls_v4-1_refined/
10
+
11
+ # Merge individual JSONs into single JSON file per split (train/val/test)
12
+ # This must be done after floorplan_extraction.py
13
+ python combine_json.py \
14
+ --input data/coco_cubicasa5k_nowalls_v4-1_refined/ \
15
+ --output data/coco_cubicasa5k_nowalls_v4-1_refined/annotations/ \
data_preprocess/cubicasa5k/svg_utils.py ADDED
@@ -0,0 +1,746 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from logging import warning
3
+
4
+ import numpy as np
5
+ from skimage.draw import polygon
6
+ from svgpathtools import parse_path
7
+
8
+
9
+ def get_room_number(e, rooms):
10
+ name_list = e.getAttribute("class").split(" ")
11
+ room_type = name_list[1]
12
+ try:
13
+ return rooms[room_type]
14
+ except KeyError:
15
+ warning("Room type " + e.getAttribute("class") + " not defined.")
16
+ return rooms["Undefined"]
17
+
18
+
19
+ def get_icon_number(e, icons):
20
+ name_list = e.getAttribute("class").split(" ")
21
+ icon_type = name_list[1]
22
+
23
+ try:
24
+ return icons[icon_type]
25
+ except KeyError:
26
+ warning("Icon type " + e.getAttribute("class") + " not defined.")
27
+ return icons["Misc"]
28
+
29
+
30
+ def get_icon(ee):
31
+ parent_transform = None
32
+ if ee.parentNode.getAttribute("class") == "FixedFurnitureSet":
33
+ parent_transform = ee.parentNode.getAttribute("transform")
34
+ strings = parent_transform.split(",")
35
+ a_p = float(strings[0][7:])
36
+ b_p = float(strings[1])
37
+ c_p = float(strings[2])
38
+ d_p = float(strings[3])
39
+ e_p = float(strings[-2])
40
+ f_p = float(strings[-1][:-1])
41
+ M_p = np.array([[a_p, c_p, e_p], [b_p, d_p, f_p], [0, 0, 1]])
42
+
43
+ transform = ee.getAttribute("transform")
44
+ strings = transform.split(",")
45
+ a = float(strings[0][7:])
46
+ b = float(strings[1])
47
+ c = float(strings[2])
48
+ d = float(strings[3])
49
+ e = float(strings[-2])
50
+ f = float(strings[-1][:-1])
51
+
52
+ M = np.array([[a, c, e], [b, d, f], [0, 0, 1]])
53
+
54
+ X = np.array([])
55
+ Y = np.array([])
56
+
57
+ try:
58
+ toilet = next(p for p in ee.childNodes if p.nodeName == "g" and p.getAttribute("class") == "BoundaryPolygon")
59
+
60
+ for p in toilet.childNodes:
61
+ if p.nodeName == "polygon":
62
+ X, Y = get_icon_polygon(p)
63
+ break
64
+ else:
65
+ x_all, y_all = get_corners(toilet)
66
+ points = np.column_stack((x_all, y_all))
67
+
68
+ X, Y = get_max_corners(points)
69
+ # if p.nodeName == "path":
70
+ # X, Y = get_icon_path(p)
71
+
72
+ except StopIteration:
73
+ X, Y = make_boudary_polygon(ee)
74
+
75
+ if len(X) < 4:
76
+ return None, None, X, Y
77
+
78
+ if parent_transform is not None:
79
+ for i in range(len(X)):
80
+ v = np.matrix([[X[i]], [Y[i]], [1]])
81
+ vv = np.matmul(M, v)
82
+ new_x, new_y, _ = np.round(np.matmul(M_p, vv))
83
+ X[i] = new_x
84
+ Y[i] = new_y
85
+ else:
86
+ for i in range(len(X)):
87
+ v = np.matrix([[X[i]], [Y[i]], [1]])
88
+ vv = np.matmul(M, v)
89
+ new_x, new_y, _ = np.round(vv)
90
+ X[i] = new_x
91
+ Y[i] = new_y
92
+
93
+ rr, cc = polygon(Y, X)
94
+
95
+ return rr, cc, X, Y
96
+
97
+
98
+ def get_corners(g):
99
+ x_all, y_all = [], []
100
+ for pol in g.childNodes:
101
+ if pol.nodeName == "polygon":
102
+ x, y = get_icon_polygon(pol)
103
+ x_all = np.append(x_all, x)
104
+ y_all = np.append(y_all, y)
105
+ elif pol.nodeName == "path":
106
+ x, y = get_icon_path(pol)
107
+ x_all = np.append(x_all, x)
108
+ y_all = np.append(y_all, y)
109
+ elif pol.nodeName == "rect":
110
+ x = pol.getAttribute("x")
111
+ if x == "":
112
+ x = 1.0
113
+ else:
114
+ x = float(x)
115
+ y = pol.getAttribute("y")
116
+ if y == "":
117
+ y = 1.0
118
+ else:
119
+ y = float(y)
120
+ x_all = np.append(x_all, x)
121
+ y_all = np.append(y_all, y)
122
+ w = float(pol.getAttribute("width"))
123
+ h = float(pol.getAttribute("height"))
124
+ x_all = np.append(x_all, x + w)
125
+ y_all = np.append(y_all, y + h)
126
+
127
+ return x_all, y_all
128
+
129
+
130
+ def get_max_corners(points):
131
+ if len(points) == 0:
132
+ return [], []
133
+
134
+ minx, miny = float("inf"), float("inf")
135
+ maxx, maxy = float("-inf"), float("-inf")
136
+ for x, y in points:
137
+ # Set min coords
138
+ if x < minx:
139
+ minx = x
140
+ if y < miny:
141
+ miny = y
142
+ # Set max coords
143
+ if x > maxx:
144
+ maxx = x
145
+ elif y > maxy:
146
+ maxy = y
147
+
148
+ X = np.array([minx, maxx, maxx, minx])
149
+ Y = np.array([miny, miny, maxy, maxy])
150
+
151
+ return X, Y
152
+
153
+
154
+ def make_boudary_polygon(pol):
155
+ g_gen = (c for c in pol.childNodes if c.nodeName == "g")
156
+
157
+ x_all, y_all = [], []
158
+ for g in g_gen:
159
+ x, y = get_corners(g)
160
+ x_all = np.append(x_all, x)
161
+ y_all = np.append(y_all, y)
162
+
163
+ points = np.column_stack((x_all, y_all))
164
+ X, Y = get_max_corners(points)
165
+
166
+ return X, Y
167
+
168
+
169
+ def get_icon_path(pol):
170
+ path = pol.getAttribute("d")
171
+ try:
172
+ path_alt = parse_path(path)
173
+ minx, maxx, miny, maxy = path_alt.bbox()
174
+ except ValueError as e:
175
+ print("Error handled")
176
+ print(e)
177
+ return np.array([]), np.array([])
178
+
179
+ X = np.array([minx, maxx, maxx, minx])
180
+ Y = np.array([miny, miny, maxy, maxy])
181
+
182
+ if np.unique(X).size == 1 or np.unique(Y).size == 1:
183
+ return np.array([]), np.array([])
184
+
185
+ return X, Y
186
+
187
+
188
+ def get_icon_polygon(pol):
189
+ points = pol.getAttribute("points").split(" ")
190
+
191
+ return get_XY(points)
192
+
193
+
194
+ def get_XY(points):
195
+ if points[-1] == "":
196
+ points = points[:-1]
197
+
198
+ if points[0] == "":
199
+ points = points[1:]
200
+
201
+ X, Y = np.array([]), np.array([])
202
+ i = 0
203
+ for a in points:
204
+ if "," in a:
205
+ if len(a) == 2:
206
+ x, y = a.split(",")
207
+ else:
208
+ num_list = a.split(",")
209
+ x, y = num_list[0], num_list[1]
210
+ X = np.append(X, np.round(float(x)))
211
+ Y = np.append(Y, np.round(float(y)))
212
+ else:
213
+ # if no comma every other is x and every other is y
214
+ if i % 2:
215
+ Y = np.append(Y, float(a))
216
+ else:
217
+ X = np.append(X, float(a))
218
+
219
+ i += 1
220
+
221
+ return X, Y
222
+
223
+
224
+ def get_points(e):
225
+ pol = next(p for p in e.childNodes if p.nodeName == "polygon")
226
+ points = pol.getAttribute("points").split(" ")
227
+ points = points[:-1]
228
+
229
+ X, Y = np.array([]), np.array([])
230
+ for a in points:
231
+ x, y = a.split(",")
232
+ X = np.append(X, np.round(float(x)))
233
+ Y = np.append(Y, np.round(float(y)))
234
+
235
+ return X, Y
236
+
237
+
238
+ def get_direction(X, Y):
239
+ max_diff_X = abs(max(X) - min(X))
240
+ max_diff_Y = abs(max(Y) - min(Y))
241
+
242
+ if max_diff_X > max_diff_Y:
243
+ return "H" # horizontal
244
+ else:
245
+ return "V" # vertical
246
+
247
+
248
+ def get_polygon(e):
249
+ pol = next(p for p in e.childNodes if p.nodeName == "polygon")
250
+ points = pol.getAttribute("points").split(" ")
251
+ points = points[:-1]
252
+
253
+ X, Y = np.array([]), np.array([])
254
+ for a in points:
255
+ y, x = a.split(",")
256
+ X = np.append(X, np.round(float(x)))
257
+ Y = np.append(Y, np.round(float(y)))
258
+
259
+ rr, cc = polygon(X, Y)
260
+
261
+ return rr, cc
262
+
263
+
264
+ def calc_distance(point_1, point_2):
265
+ return math.sqrt(math.pow(point_1[0] - point_2[0], 2) + math.pow(point_1[1] - point_2[1], 2))
266
+
267
+
268
+ def calc_center(points):
269
+ return list(np.mean(np.array(points), axis=0))
270
+
271
+
272
+ def get_gaussian2D(ndim, sigma=0.25):
273
+ over_sigmau = 1.0 / (sigma * ndim)
274
+ over_sigmav = 1.0 / (sigma * ndim)
275
+ dst_data = np.zeros((ndim, ndim))
276
+
277
+ mean_u = 0.5 * ndim + 0.5
278
+ mean_v = 0.5 * ndim + 0.5
279
+
280
+ for v in range(ndim):
281
+ for u in range(ndim):
282
+ du = (u + 1 - mean_u) * over_sigmau
283
+ dv = (v + 1 - mean_v) * over_sigmav
284
+ value = np.exp(-0.5 * (du * du + dv * dv))
285
+ dst_data[v][u] = value
286
+
287
+ return dst_data
288
+
289
+
290
+ def draw_junction(index, point, width, height, axes):
291
+ lineLength = 15
292
+ lineWidth = 7
293
+ x, y = point[0]
294
+ axes.text(x, y, str(index), fontsize=15, color="k")
295
+ ###########################
296
+ # o
297
+ # | #6488ea soft blue
298
+ # | drawcode = [1,1]
299
+ #
300
+ ###########################
301
+ if point[2][1] == 1 and point[2][2] == 1:
302
+ axes.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="#6488ea")
303
+ ###########################
304
+ #
305
+ # ---o #6241c7 bluey purple
306
+ # drawcode = [1,2]
307
+ #
308
+ ###########################
309
+ elif point[2][1] == 1 and point[2][2] == 2:
310
+ axes.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="#6241c7")
311
+ ###########################
312
+ # |
313
+ # | drawcode = [1,3]
314
+ # o #056eee cerulean blue
315
+ #
316
+ ###########################
317
+ elif point[2][1] == 1 and point[2][2] == 3:
318
+ axes.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="#056eee")
319
+ ###########################
320
+ #
321
+ # drawcode = [1,4]
322
+ #
323
+ # o--- #004577 prussian blue
324
+ #
325
+ ###########################
326
+ elif point[2][1] == 1 and point[2][2] == 4:
327
+ axes.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="#004577")
328
+ ###########################
329
+ #
330
+ # |--- drawcode = [2,3]
331
+ # |
332
+ #
333
+ ###########################
334
+ elif point[2][1] == 2 and point[2][2] == 3:
335
+ axes.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="#04d8b2")
336
+ axes.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="#04d8b2")
337
+ ###########################
338
+ #
339
+ # ---|
340
+ # | drawcode = [2,4]
341
+ #
342
+ ###########################
343
+ elif point[2][1] == 2 and point[2][2] == 4:
344
+ axes.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="#cdfd02")
345
+ axes.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="#cdfd02")
346
+ ###########################
347
+ # |
348
+ # ---| drawcode = [2,1]
349
+ #
350
+ #
351
+ ###########################
352
+ elif point[2][1] == 2 and point[2][2] == 1:
353
+ axes.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="#ff81c0")
354
+ axes.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="#ff81c0")
355
+ ###########################
356
+ #
357
+ # |
358
+ # | drawcode = [2,2]
359
+ # --
360
+ #
361
+ ###########################
362
+ elif point[2][1] == 2 and point[2][2] == 2:
363
+ axes.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="#f97306")
364
+ axes.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="#f97306")
365
+ ###########################
366
+ #
367
+ # |
368
+ # |--- drawcode = [3,4]
369
+ # |
370
+ #
371
+ ###########################
372
+ elif point[2][1] == 3 and point[2][2] == 4:
373
+ axes.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="b")
374
+ axes.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="b")
375
+ axes.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="b")
376
+ ###########################
377
+ #
378
+ # ---
379
+ # | drawcode = [3,1]
380
+ # |
381
+ #
382
+ ###########################
383
+ elif point[2][1] == 3 and point[2][2] == 1:
384
+ axes.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="y")
385
+ axes.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="y")
386
+ axes.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="y")
387
+ ###########################
388
+ #
389
+ # |
390
+ # ---| drawcode = [3,2]
391
+ # |
392
+ #
393
+ ###########################
394
+ elif point[2][1] == 3 and point[2][2] == 2:
395
+ axes.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="r")
396
+ axes.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="r")
397
+ axes.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="r")
398
+ ###########################
399
+ #
400
+ # |
401
+ # | drawcode = [3,3]
402
+ # ---
403
+ #
404
+ ###########################
405
+ elif point[2][1] == 3 and point[2][2] == 3:
406
+ axes.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="m")
407
+ axes.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="m")
408
+ axes.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="m")
409
+ ###########################
410
+ #
411
+ # |
412
+ # --- drawcode = [4,1]
413
+ # |
414
+ #
415
+ ###########################
416
+ elif point[2][1] == 4 and point[2][2] == 1:
417
+ axes.plot([x, min(x + lineLength, width - 1)], [y, y], linewidth=lineWidth, color="k")
418
+ axes.plot([x, max(x - lineLength, 0)], [y, y], linewidth=lineWidth, color="k")
419
+ axes.plot([x, x], [y, max(y - lineLength, 0)], linewidth=lineWidth, color="k")
420
+ axes.plot([x, x], [y, min(y + lineLength, height - 1)], linewidth=lineWidth, color="k")
421
+
422
+
423
+ class Wall:
424
+ def __init__(self, id, end_points, direction, width, name):
425
+ self.id = id
426
+ self.name = name
427
+ self.end_points = end_points
428
+ self.direction = direction
429
+ self.max_width = width
430
+ self.min_width = width
431
+
432
+ def change_end_points(self):
433
+ if self.direction == "V":
434
+ self.end_points[0][0] = np.mean(np.array(self.min_coord))
435
+ self.end_points[1][0] = self.end_points[0][0]
436
+ elif self.direction == "H":
437
+ self.end_points[0][1] = np.mean(np.array(self.min_coord))
438
+ self.end_points[1][1] = self.end_points[0][1]
439
+
440
+ def get_length(self, end_points):
441
+ return calc_distance(end_points[0], end_points[1])
442
+
443
+
444
+ class LineWall(Wall):
445
+ def __init__(self, id, end_points, direction, width, name):
446
+ Wall.__init__(self, id, end_points, direction, width, name)
447
+
448
+
449
+ class PolygonWall(Wall):
450
+ def __init__(self, e, id, shape=None):
451
+ self.id = id
452
+ self.name = e.getAttribute("id")
453
+ self.X, self.Y = self.get_points(e)
454
+ if abs(max(self.X) - min(self.X)) < 4 or abs(max(self.Y) - min(self.Y)) < 4:
455
+ # wall is too small and we ignore it.
456
+ raise ValueError("small wall")
457
+ if shape:
458
+ self.X = np.clip(self.X, 0, shape[1])
459
+ self.Y = np.clip(self.Y, 0, shape[0])
460
+ # self.X, self.Y = self.sort_X_Y(self.X, self.Y)
461
+ self.rr, self.cc = polygon(self.Y, self.X)
462
+ direction = self.get_direction(self.X, self.Y)
463
+ end_points = self.get_end_points(self.X, self.Y, direction)
464
+ self.min_width = self.get_width(self.X, self.Y, direction)
465
+ self.max_width = self.min_width
466
+
467
+ Wall.__init__(self, id, end_points, direction, self.max_width, self.name)
468
+ self.length = self.get_length(self.end_points)
469
+ self.center = self.get_center(self.X, self.Y)
470
+ self.min_coord, self.max_coord = self.get_width_coods(self.X, self.Y)
471
+
472
+ def get_points(self, e):
473
+ pol = next(p for p in e.childNodes if p.nodeName == "polygon")
474
+ points = pol.getAttribute("points").split(" ")
475
+ points = points[:-1]
476
+
477
+ X, Y = np.array([]), np.array([])
478
+ for a in points:
479
+ x, y = a.split(",")
480
+ X = np.append(X, np.round(float(x)))
481
+ Y = np.append(Y, np.round(float(y)))
482
+
483
+ return X, Y
484
+
485
+ def get_direction(self, X, Y):
486
+ max_diff_X = abs(max(X) - min(X))
487
+ max_diff_Y = abs(max(Y) - min(Y))
488
+
489
+ if max_diff_X > max_diff_Y:
490
+ return "H" # horizontal
491
+ else:
492
+ return "V" # vertical
493
+
494
+ def get_center(self, X, Y):
495
+ return np.mean(X), np.mean(Y)
496
+
497
+ def get_width(self, X, Y, direction):
498
+ _, _, p1, p2 = self._get_min_points(X, Y)
499
+
500
+ if direction == "H":
501
+ return (abs(p1[0][1] - p1[1][1]) + abs(p2[0][1] - p2[1][1])) / 2
502
+ elif "V":
503
+ return (abs(p1[0][0] - p1[1][0]) + abs(p2[0][0] - p2[1][0])) / 2
504
+
505
+ def _width(self, values):
506
+ temp = values.tolist() if type(values) is not list else values
507
+
508
+ mean_1 = min(temp)
509
+ mean_2 = max(temp)
510
+
511
+ return abs(mean_1 - mean_2)
512
+
513
+ def merge_possible(self, merged):
514
+ max_dist = max([self.max_width, merged.max_width])
515
+
516
+ if self.id == merged.id:
517
+ return False
518
+
519
+ # walls have to be in the same direction
520
+ if self.direction != merged.direction:
521
+ return False
522
+
523
+ # walls have too big width difference
524
+ if abs(self.max_width - merged.max_width) > merged.max_width:
525
+ return False
526
+
527
+ # If endpoints are near
528
+ # self up and left endpoint to merged down and right end point
529
+ dist1 = calc_distance(self.end_points[0], merged.end_points[1])
530
+ # self down and right endpoint to merged up and left end point
531
+ dist2 = calc_distance(self.end_points[1], merged.end_points[0])
532
+
533
+ if dist1 <= max_dist * 1.5 or dist2 <= max_dist * 1.5:
534
+ return True
535
+ else:
536
+ return False
537
+
538
+ def _get_overlap(self, a, b):
539
+ return max(0, min(a[1], b[1]) - max(a[0], b[0]))
540
+
541
+ def merge_walls(self, merged):
542
+ max_dist = max([self.max_width, merged.max_width])
543
+
544
+ if self.id == merged.id:
545
+ return None
546
+
547
+ # walls have to be in the same direction
548
+ if self.direction != merged.direction:
549
+ return None
550
+
551
+ # If endpoints are near
552
+ # self up and left endpoint to merged down and right end point
553
+ dist1 = calc_distance(self.end_points[0], merged.end_points[1])
554
+ # self down and right endpoint to merged up and left end point
555
+ dist2 = calc_distance(self.end_points[1], merged.end_points[0])
556
+
557
+ if dist1 <= max_dist * 1.5:
558
+ if self._get_overlap(self.min_coord, merged.min_coord) <= 0:
559
+ return None
560
+ # merged is on top or on left
561
+ return self.do_merge(merged, 0)
562
+ elif dist2 <= max_dist * 1.5:
563
+ if self._get_overlap(self.min_coord, merged.min_coord) <= 0:
564
+ return None
565
+ # merged is on down or on right
566
+ return self.do_merge(merged, 1)
567
+ else:
568
+ return None
569
+
570
+ def _get_min_points(self, X, Y):
571
+ assert len(X) == 4 and len(Y) == 4
572
+ length = len(X)
573
+ min_dist1 = np.inf
574
+ min_dist2 = np.inf
575
+ point1 = None
576
+ point2 = None
577
+ corners1 = None
578
+ corners2 = None
579
+
580
+ for i in range(length):
581
+ x1, y1 = X[i], Y[i]
582
+ x2, y2 = X[(i + 1) % 4], Y[(i + 1) % 4]
583
+
584
+ dist = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
585
+ if dist < min_dist1:
586
+ point2 = point1
587
+ point1 = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
588
+ min_dist2 = min_dist1
589
+ min_dist1 = dist
590
+ corners2 = corners1
591
+ corners1 = np.array([[x1, y1], [x2, y2]])
592
+ elif dist <= min_dist2:
593
+ point2 = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
594
+ min_dist2 = dist
595
+ corners2 = np.array([[x1, y1], [x2, y2]])
596
+
597
+ return point1, point2, corners1, corners2
598
+
599
+ def get_end_points(self, X, Y, direction):
600
+ point1, point2, _, _ = self._get_min_points(X, Y)
601
+
602
+ if point1[0] != point2[0] or point1[1] != point2[1]:
603
+ if abs(point1[0] - point2[0]) > abs(point1[1] - point2[1]):
604
+ # horizontal
605
+ point1[1] = point1[1] + point2[1] / 2.0
606
+ point2[1] = point1[1]
607
+ # point1[1] = int(np.round(point1[1]))
608
+ # point2[1] = int(np.round(point2[1]))
609
+ else:
610
+ # vertical
611
+ point1[0] = point1[0] + point2[0] / 2.0
612
+ point2[0] = point1[0]
613
+ # point1[0] = int(np.round(point1[0]))
614
+ # point2[0] = int(np.round(point2[0]))
615
+
616
+ return self.sort_end_points(direction, point1, point2)
617
+
618
+ def sort_end_points(self, direction, point1, point2):
619
+ if direction == "V":
620
+ if point1[1] < point2[1]:
621
+ return np.array([point1, point2])
622
+ else:
623
+ return np.array([point2, point1])
624
+ else:
625
+ if point1[0] < point2[0]:
626
+ return np.array([point1, point2])
627
+ else:
628
+ return np.array([point2, point1])
629
+
630
+ def do_merge(self, merged, direction):
631
+ # update width
632
+ self.max_width = max([self.max_width, merged.max_width])
633
+ self.min_width = min([self.min_width, merged.min_width])
634
+
635
+ # update polygon
636
+ self.X = np.concatenate((self.X, merged.X))
637
+ self.Y = np.concatenate((self.Y, merged.Y))
638
+
639
+ # update width coordinates
640
+ self.max_coord = self.get_max_width_coord(merged)
641
+ self.min_coord = self.get_min_width_coord(merged)
642
+
643
+ if direction == 0:
644
+ # merged wall is up or left to the original wall
645
+ self.end_points = np.array([merged.end_points[0], self.end_points[1]])
646
+ else:
647
+ # merged wall is down or right to the original wall
648
+ self.end_points = np.array([self.end_points[0], merged.end_points[1]])
649
+
650
+ self.length = self.get_length(self.end_points)
651
+
652
+ return self
653
+
654
+ def get_max_width_coord(self, merged):
655
+ width_1 = abs(self.max_coord[0] - self.max_coord[1])
656
+ width_2 = abs(merged.max_coord[0] - merged.max_coord[1])
657
+ return self.max_coord if width_1 > width_2 else merged.max_coord
658
+
659
+ def get_min_width_coord(self, merged):
660
+ width_1 = max(merged.min_coord[0], self.min_coord[0])
661
+ # width_1 = abs(self.min_coord[0] - self.min_coord[1])
662
+ width_2 = min(merged.min_coord[1], self.min_coord[1])
663
+ # width_2 = abs(merged.min_coord[0] - merged.min_coord[1])
664
+ # return self.min_coord if width_1 < width_2 else merged.min_coord
665
+ return [width_1, width_2]
666
+
667
+ def get_width_coods(self, X, Y):
668
+ if self.direction == "H":
669
+ dist_1 = abs(Y[0] - Y[2])
670
+ dist_2 = abs(Y[1] - Y[3])
671
+ if dist_1 < dist_2:
672
+ return [Y[0], Y[2]], [Y[1], Y[3]]
673
+ else:
674
+ return [Y[1], Y[3]], [Y[0], Y[2]]
675
+
676
+ elif self.direction == "V":
677
+ dist_1 = abs(X[0] - X[3])
678
+ dist_2 = abs(X[1] - X[2])
679
+ if dist_1 < dist_2:
680
+ return [X[0], X[3]], [X[1], X[2]]
681
+ else:
682
+ return [X[1], X[2]], [X[0], X[3]]
683
+
684
+ def sort_X_Y(self, X, Y):
685
+ max_x = max(X)
686
+ min_x = min(X)
687
+ max_y = max(Y)
688
+ min_y = min(Y)
689
+ res_X, res_Y = [0] * 4, [0] * 4
690
+ # top left 0, top right 1, bottom left 2, bottom right 3
691
+ directions = [[min_x, min_y], [max_x, min_y], [min_x, max_y], [max_x, max_y]]
692
+ length = len(X)
693
+ for i in range(length):
694
+ min_dist = 1000000
695
+ direction_candidate = None
696
+ for j, direc in enumerate(directions):
697
+ dist = calc_distance([X[i], Y[i]], direc)
698
+ if dist < min_dist:
699
+ min_dist = dist
700
+ direction_candidate = j
701
+
702
+ res_X[direction_candidate] = X[i]
703
+ res_Y[direction_candidate] = Y[i]
704
+
705
+ return res_X, res_Y
706
+
707
+ def wall_is_pillar(self, avg_wall_width):
708
+ if self.max_width > avg_wall_width:
709
+ if self.length < 3 * self.max_width:
710
+ return True
711
+
712
+ return False
713
+
714
+ def split_pillar_wall(self, ids, avg_wall_width):
715
+ half = avg_wall_width / 3.0
716
+ end_points = [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]
717
+ self.X[np.argmax(self.X)] = max(self.X) - half
718
+ self.X[np.argmax(self.X)] = max(self.X) - half
719
+ self.X[np.argmin(self.X)] = min(self.X) + half
720
+ self.X[np.argmin(self.X)] = min(self.X) + half
721
+ self.Y[np.argmax(self.Y)] = max(self.Y) - half
722
+ self.Y[np.argmax(self.Y)] = max(self.Y) - half
723
+ self.Y[np.argmin(self.Y)] = min(self.Y) + half
724
+ self.Y[np.argmin(self.Y)] = min(self.Y) + half
725
+ for i in range(4):
726
+ x = self.X[i]
727
+ y = self.Y[i]
728
+ end = [x, y]
729
+ j = i % 2
730
+ end_points[i][j] = end
731
+ end_points[(i + 3) % 4][j] = end
732
+
733
+ walls = []
734
+ for i, e in enumerate(end_points):
735
+ if abs(e[0][1] - e[1][1]) > abs(e[0][0] - e[1][0]):
736
+ # vertical wall
737
+ direction = "V"
738
+ else:
739
+ # horizontal wall
740
+ direction = "H"
741
+
742
+ e = self.sort_end_points(direction, e[0], e[1])
743
+ wall = LineWall(ids + i, e, direction, avg_wall_width / 2.0, self.name)
744
+ walls.append(wall)
745
+
746
+ return walls
data_preprocess/raster2graph/combine_json.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import os
4
+ import shutil
5
+ from pathlib import Path
6
+
7
+
8
+ def combine_json_files(input_pattern, data_path, split_type, output_file, output_image_dir, start_image_id=0):
9
+ """
10
+ Combines multiple COCO-style JSON annotation files into a single file.
11
+
12
+ Args:
13
+ input_pattern: Glob pattern to match the input JSON files (e.g., "annotations/*.json")
14
+ output_file: Path to the output combined JSON file
15
+ """
16
+ os.makedirs(output_image_dir, exist_ok=True)
17
+
18
+ # Initialize combined data structure
19
+ combined_data = {"images": [], "annotations": [], "categories": []}
20
+
21
+ # Track image and annotation IDs to avoid duplicates
22
+ annotation_ids_seen = set()
23
+
24
+ next_image_id = start_image_id
25
+ next_annotation_id = 0
26
+ skip_file_list = []
27
+ image_id_mapping = {}
28
+
29
+ # Find all matching JSON files
30
+ json_files = sorted(glob.glob(input_pattern))
31
+ print(f"Found {len(json_files)} JSON files to combine")
32
+
33
+ # Process each file
34
+ for i, json_file in enumerate(json_files):
35
+ print(f"Processing file {i + 1}/{len(json_files)}: {json_file}")
36
+
37
+ with open(json_file, "r") as f:
38
+ data = json.load(f)
39
+
40
+ # Store categories from the first file
41
+ if i == 0 and data.get("categories"):
42
+ combined_data["categories"] = data["categories"]
43
+
44
+ # empty annos
45
+ if len(data["annotations"]) == 0:
46
+ skip_file_list.append(data["images"][0]["id"])
47
+ continue
48
+
49
+ # Process images
50
+ for image in data.get("images", []):
51
+ if image["id"] not in image_id_mapping:
52
+ image_id_mapping[image["id"]] = next_image_id
53
+ else:
54
+ skip_file_list.append(image["id"])
55
+ continue
56
+ image["id"] = next_image_id
57
+ next_image_id += 1
58
+ # org_file_name = copy(image['file_name'])
59
+ image["file_name"] = str(image["id"]).zfill(6) + ".png"
60
+ org_file_name = os.path.basename(json_file).replace(".json", ".png")
61
+ if image["file_name"] != org_file_name and os.path.exists(f"{data_path}/{split_type}/{org_file_name}"):
62
+ shutil.copy(f"{data_path}/{split_type}/{org_file_name}", f"{output_image_dir}/{image['file_name']}")
63
+ combined_data["images"].append(image)
64
+
65
+ # Process annotations
66
+ for annotation in data.get("annotations", []):
67
+ annotation["id"] = next_annotation_id
68
+ next_annotation_id += 1
69
+ annotation["image_id"] = image_id_mapping[annotation["image_id"]]
70
+
71
+ annotation_ids_seen.add(annotation["id"])
72
+ combined_data["annotations"].append(annotation)
73
+
74
+ # Write combined data to output file
75
+ output_path = Path(output_file)
76
+ output_path.parent.mkdir(exist_ok=True, parents=True)
77
+
78
+ with open(output_file, "w") as f:
79
+ json.dump(combined_data, f, indent=2)
80
+
81
+ with open(output_path.parent / f"{output_path.name.split('.')[0]}_image_id_mapping.json", "w") as f:
82
+ json.dump(image_id_mapping, f, indent=2)
83
+
84
+ if len(skip_file_list):
85
+ with open(output_path.parent / f"{output_path.name.split('.')[0]}_skipped.txt", "w") as f:
86
+ f.write("\n".join([str(x) for x in skip_file_list]))
87
+
88
+ print(f"Combined data written to {output_file}")
89
+ print(f"Total images: {len(combined_data['images'])}")
90
+ print(f"Total annotations: {len(combined_data['annotations'])}")
91
+ print(f"Total categories: {len(combined_data['categories'])}")
92
+ print(f"Skipped images: {len(skip_file_list)}")
93
+
94
+ image_id_mapping_list = [[f"{k} {v}"] for k, v in image_id_mapping.items()] # Reverse mapping for easier lookup
95
+
96
+ return combined_data, image_id_mapping_list
97
+
98
+
99
+ if __name__ == "__main__":
100
+ import argparse
101
+
102
+ parser = argparse.ArgumentParser(description="Combine multiple COCO-style JSON annotation files")
103
+ parser.add_argument("--input", required=True, help="Glob pattern for input JSON files, e.g., 'annotations/*.json'")
104
+ parser.add_argument("--output", required=True, help="Output JSON file path")
105
+
106
+ args = parser.parse_args()
107
+
108
+ splits = ["train", "val", "test"]
109
+ for i, split in enumerate(splits):
110
+ if split == "train":
111
+ start_image_id = 0
112
+ else:
113
+ start_image_id += len(list(Path(f"{args.input}/{splits[i - 1]}").glob("*.png")))
114
+
115
+ _, image_id_mapping_list = combine_json_files(
116
+ f"{args.input}/{split}_jsons/*.json",
117
+ args.input,
118
+ split,
119
+ f"{args.output}/annotations/{split}.json",
120
+ output_image_dir=f"{args.output}/{split}",
121
+ start_image_id=start_image_id,
122
+ )
data_preprocess/raster2graph/combine_mapping_ids.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+
4
+ def generate_combined_mapping(file_mapping_path, image_id_mapping_path, output_path):
5
+ """
6
+ Generates a combined mapping file from an original filename mapping
7
+ and an image ID mapping.
8
+
9
+ Args:
10
+ file_mapping_path (str): Path to the text file mapping original filenames
11
+ to intermediate 6-digit IDs.
12
+ image_id_mapping_path (str): Path to the JSON file mapping intermediate
13
+ IDs to destination IDs.
14
+ output_path (str): Path where the new combined mapping file will be saved.
15
+ """
16
+ # 1. Read test_file_mapping.txt
17
+ org_fn_to_intermediate_id = {}
18
+ try:
19
+ with open(file_mapping_path, "r") as f:
20
+ for line in f:
21
+ parts = line.strip().split()
22
+ if len(parts) == 2:
23
+ org_fn = parts[0]
24
+ # Convert the 6-digit string ID to an integer for lookup
25
+ intermediate_id_str = parts[1]
26
+ # Remove leading zeros and convert to int
27
+ intermediate_id = int(intermediate_id_str)
28
+ org_fn_to_intermediate_id[org_fn] = intermediate_id
29
+ except FileNotFoundError:
30
+ print(f"Error: The file '{file_mapping_path}' was not found.")
31
+ return
32
+ except Exception as e:
33
+ print(f"Error reading '{file_mapping_path}': {e}")
34
+ return
35
+
36
+ # 2. Read test_image_id_mapping.json
37
+ intermediate_id_to_dst_fn = {}
38
+ try:
39
+ with open(image_id_mapping_path, "r") as f:
40
+ image_id_data = json.load(f)
41
+ for key, value in image_id_data.items():
42
+ # Keys in JSON are strings, convert to int for consistency
43
+ intermediate_id_to_dst_fn[int(key)] = value
44
+ except FileNotFoundError:
45
+ print(f"Error: The file '{image_id_mapping_path}' was not found.")
46
+ return
47
+ except json.JSONDecodeError:
48
+ print(f"Error: Could not decode JSON from '{image_id_mapping_path}'. Please ensure it's valid JSON.")
49
+ return
50
+ except Exception as e:
51
+ print(f"Error reading '{image_id_mapping_path}': {e}")
52
+ return
53
+
54
+ # 3. Create the combined mapping and write to output file
55
+ combined_mappings = []
56
+ found_mappings_count = 0
57
+ for org_fn, intermediate_id in org_fn_to_intermediate_id.items():
58
+ if intermediate_id in intermediate_id_to_dst_fn:
59
+ dst_fn = intermediate_id_to_dst_fn[intermediate_id]
60
+ combined_mappings.append(f"{org_fn} {dst_fn}")
61
+ found_mappings_count += 1
62
+ else:
63
+ # Optionally, you can print a warning for IDs not found
64
+ print(f"Warning: Intermediate ID '{intermediate_id}' for '{org_fn}' not found in image ID mapping.")
65
+
66
+ try:
67
+ with open(output_path, "w") as f:
68
+ for mapping_line in combined_mappings:
69
+ f.write(mapping_line + "\n")
70
+ print(f"\nSuccessfully generated combined mapping to '{output_path}'.")
71
+ print(f"Total original filenames processed: {len(org_fn_to_intermediate_id)}")
72
+ print(f"Total combined mappings written: {found_mappings_count}")
73
+ except Exception as e:
74
+ print(f"Error writing to output file '{output_path}': {e}")
75
+
76
+
77
+ # Define file paths
78
+ file_mapping_path = "data/R2G_hr_dataset_processed/test_file_mapping.txt"
79
+ image_id_mapping_path = "data/R2G_hr_dataset_processed_v1/annotations/test_image_id_mapping.json"
80
+ output_mapping_path = "data/R2G_hr_dataset_processed_v1/annotations/test_combined_mapping.txt"
81
+
82
+ # Run the mapping function
83
+ generate_combined_mapping(file_mapping_path, image_id_mapping_path, output_mapping_path)
84
+
85
+ # You can optionally print the content of the generated file to verify
86
+ print("\n--- Content of combined_mapping.txt ---")
87
+ try:
88
+ with open(output_mapping_path, "r") as f:
89
+ print(f.read())
90
+ except FileNotFoundError:
91
+ print("Output file was not created.")
92
+
93
+ # Clean up dummy files (optional)
94
+ # os.remove(file_mapping_path)
95
+ # os.remove(image_id_mapping_path)
data_preprocess/raster2graph/convert_to_coco.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import sys
4
+
5
+ print(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
+
8
+ import argparse
9
+ import json
10
+ import shutil
11
+ from multiprocessing import Pool
12
+
13
+ import cv2
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ import torch
17
+ from datasets.dataset import MyDataset
18
+ from matplotlib.patches import Patch
19
+ from shapely.geometry import Polygon
20
+ from tqdm import tqdm
21
+ from util.data_utils import edge_inside
22
+ from util.graph_utils import get_cycle_basis_and_semantic, tensors_to_graphs_batch
23
+
24
+ mean = [0.920, 0.913, 0.891]
25
+ std = [0.214, 0.216, 0.228]
26
+
27
+ ID2CLASS = {
28
+ 0: "unknown",
29
+ 1: "living_room",
30
+ 2: "kitchen",
31
+ 3: "bedroom",
32
+ 4: "bathroom",
33
+ 5: "restroom",
34
+ 6: "balcony",
35
+ 7: "closet",
36
+ 8: "corridor",
37
+ 9: "washing_room",
38
+ 10: "PS",
39
+ 11: "outside",
40
+ # 12: 'wall'
41
+ }
42
+
43
+
44
+ def plot_room_map(preds, room_map, room_id=0, im_size=256, plot_text=True):
45
+ """Draw room polygons overlaid on the density map"""
46
+ centroid_x = int(np.mean(preds[:, 0]))
47
+ centroid_y = int(np.mean(preds[:, 1]))
48
+
49
+ # Get text size to create a background box
50
+ font = cv2.FONT_HERSHEY_SIMPLEX
51
+ font_scale = 0.3
52
+ thickness = 1
53
+ text = str(room_id)
54
+ (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)
55
+ border_color = (252, 252, 0)
56
+
57
+ for i, corner in enumerate(preds):
58
+ if i == len(preds) - 1:
59
+ cv2.line(
60
+ room_map,
61
+ (round(corner[0]), round(corner[1])),
62
+ (round(preds[0][0]), round(preds[0][1])),
63
+ border_color,
64
+ 2,
65
+ )
66
+ else:
67
+ cv2.line(
68
+ room_map,
69
+ (round(corner[0]), round(corner[1])),
70
+ (round(preds[i + 1][0]), round(preds[i + 1][1])),
71
+ border_color,
72
+ 2,
73
+ )
74
+ cv2.circle(room_map, (round(corner[0]), round(corner[1])), 2, (0, 0, 255), 2)
75
+ # cv2.putText(room_map, str(i), (round(corner[0]), round(corner[1])), cv2.FONT_HERSHEY_SIMPLEX,
76
+ # 0.4, (0, 255, 0), 1, cv2.LINE_AA)
77
+
78
+ # Draw white background box with transparency
79
+ # overlay = room_map.copy()
80
+ # cv2.addWeighted(overlay, 0.7, room_map, 0.3, 0, room_map) # 70% opacity
81
+
82
+ # Draw text
83
+ if plot_text:
84
+ cv2.rectangle(
85
+ room_map,
86
+ (centroid_x - text_width // 2 - 2, centroid_y - text_height // 2 - 2),
87
+ (centroid_x + text_width // 2 + 2, centroid_y + text_height // 2 + 2),
88
+ (255, 255, 255), # (0, 0, 0),
89
+ -1,
90
+ ) # Filled rectangle
91
+ cv2.putText(
92
+ room_map,
93
+ text,
94
+ (centroid_x - text_width // 2, centroid_y + text_height // 2),
95
+ font,
96
+ font_scale,
97
+ (0, 100, 0),
98
+ thickness,
99
+ )
100
+
101
+ return room_map
102
+
103
+
104
+ def plot_density_map(sample, image_size, room_polys, pred_room_label_per_scene, plot_text=True):
105
+ if not isinstance(sample, np.ndarray):
106
+ density_map = np.transpose(sample.cpu().numpy(), [1, 2, 0])
107
+ # # Convert to grayscale if not already
108
+ # if density_map.shape[2] > 1:
109
+ # density_map = cv2.cvtColor(density_map, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
110
+ else:
111
+ density_map = sample
112
+ if density_map.shape[2] == 3:
113
+ density_map = density_map * (image_size - 1)
114
+ else:
115
+ density_map = np.repeat(density_map, 3, axis=2) * (image_size - 1)
116
+ pred_room_map = np.zeros([image_size, image_size, 3])
117
+
118
+ for room_poly, room_id in zip(room_polys, pred_room_label_per_scene):
119
+ pred_room_map = plot_room_map(
120
+ np.array(room_poly), pred_room_map, room_id, im_size=image_size, plot_text=plot_text
121
+ )
122
+
123
+ alpha = 0.4 # Adjust for desired transparency
124
+ pred_room_map = cv2.addWeighted(density_map.astype(np.uint8), alpha, pred_room_map.astype(np.uint8), 1 - alpha, 0)
125
+ return pred_room_map
126
+
127
+
128
+ def is_clockwise(points):
129
+ # points is a list of 2d points.
130
+ assert len(points) > 0
131
+ s = 0.0
132
+ for p1, p2 in zip(points, points[1:] + [points[0]]):
133
+ s += (p2[0] - p1[0]) * (p2[1] + p1[1])
134
+ return s > 0.0
135
+
136
+
137
+ def resort_corners(corners):
138
+ # re-find the starting point and sort corners clockwisely
139
+ x_y_square_sum = corners[:, 0] ** 2 + corners[:, 1] ** 2
140
+ start_corner_idx = np.argmin(x_y_square_sum)
141
+
142
+ corners_sorted = np.concatenate([corners[start_corner_idx:], corners[:start_corner_idx]])
143
+
144
+ ## sort points clockwise
145
+ if not is_clockwise(corners_sorted[:, :2].tolist()):
146
+ corners_sorted[1:] = np.flip(corners_sorted[1:], 0)
147
+
148
+ return corners
149
+
150
+
151
+ def create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2):
152
+ bb_x = np.unique(bb_x)
153
+ bb_y = np.unique(bb_y)
154
+ bb_x_min = np.maximum(np.min(bb_x) - bound_pad, 0)
155
+ bb_y_min = np.maximum(np.min(bb_y) - bound_pad, 0)
156
+
157
+ bb_x_max = np.minimum(np.max(bb_x) + bound_pad, image_width - 1)
158
+ bb_y_max = np.minimum(np.max(bb_y) + bound_pad, image_height - 1)
159
+
160
+ bb_width = bb_x_max - bb_x_min
161
+ bb_height = bb_y_max - bb_y_min
162
+
163
+ coco_bb = [bb_x_min, bb_y_min, bb_width, bb_height]
164
+ return coco_bb
165
+
166
+
167
+ def prepare_dict():
168
+ save_dict = {"images": [], "annotations": [], "categories": []}
169
+ for key, value in ID2CLASS.items():
170
+ if key == 0:
171
+ continue
172
+ type_dict = {"supercategory": "room", "id": key, "name": value}
173
+ save_dict["categories"].append(type_dict)
174
+ return save_dict
175
+
176
+
177
+ def get_args_parser():
178
+ parser = argparse.ArgumentParser()
179
+ parser.add_argument(
180
+ "--dataset_path",
181
+ type=str,
182
+ required=True,
183
+ help="Path to the dataset directory",
184
+ )
185
+ parser.add_argument(
186
+ "--output_dir",
187
+ type=str,
188
+ required=True,
189
+ help="Path to the dataset directory",
190
+ )
191
+ # Add more arguments as needed
192
+ return parser
193
+
194
+
195
+ def visualize_room_polygons(room_polygons, room_classes, image_size=512, save_path="cubicasa_debug.png"):
196
+ """
197
+ Visualize the extracted room polygons.
198
+
199
+ Args:
200
+ room_polygons: Dictionary of room polygons as returned by extract_room_polygons
201
+ figsize: Figure size for the plot
202
+ """
203
+ # Set figure size to exactly 256x256 pixels
204
+ dpi = 100 # Standard screen DPI
205
+ figsize = (image_size / dpi, image_size / dpi) # Convert pixels to inches
206
+ class_names = [v for k, v in ID2CLASS.items()]
207
+
208
+ # Get unique classes from the mask
209
+ unique_classes = list(ID2CLASS.keys())
210
+
211
+ # Create a discrete colormap
212
+ cmap = plt.cm.get_cmap("gist_ncar", 256) # nipy_spectral
213
+ norm = np.linspace(0, 1, 13) # int(max(unique_classes))+1
214
+
215
+ fig = plt.figure(figsize=figsize, dpi=dpi)
216
+ ax = fig.add_axes([0, 0, 1, 1])
217
+ ax.set_xlim(0, image_size)
218
+ ax.set_ylim(0, image_size)
219
+ ax.set_aspect("equal")
220
+ ax.axis("off")
221
+
222
+ # Plot each room polygon and fill with color
223
+ for polygon, room_cls in zip(room_polygons, room_classes):
224
+ polygon_array = np.array(polygon).copy()
225
+ polygon_array[:, 1] = image_size - 1 - polygon_array[:, 1] # flip
226
+ # Fill the polygon with its class color
227
+ color = cmap(norm[int(room_cls)])
228
+ ax.fill(polygon_array[:, 0], polygon_array[:, 1], color=color, alpha=0.4, zorder=1)
229
+ # Draw the polygon border
230
+ ax.plot(polygon_array[:, 0], polygon_array[:, 1], "k-", linewidth=2, zorder=2)
231
+
232
+ # Add room ID label at the centroid
233
+ centroid_x = np.mean(polygon_array[:, 0])
234
+ centroid_y = np.mean(polygon_array[:, 1])
235
+ ax.text(
236
+ centroid_x,
237
+ centroid_y,
238
+ str(room_cls),
239
+ fontsize=12,
240
+ ha="center",
241
+ va="center",
242
+ bbox=dict(facecolor="white", alpha=0.7),
243
+ zorder=3,
244
+ )
245
+
246
+ # Create custom legend elements
247
+ legend_elements = []
248
+ for i, cls in enumerate(sorted(unique_classes)):
249
+ color = cmap(norm[int(cls)])
250
+ cls_name = f"{int(cls)}_{class_names[int(cls)]}"
251
+ legend_elements.append(Patch(facecolor=color, edgecolor="black", label=f"{cls_name}", alpha=0.6))
252
+ ax.legend(
253
+ handles=legend_elements,
254
+ loc="best",
255
+ title="Classes",
256
+ fontsize=10,
257
+ markerscale=1,
258
+ title_fontsize=12,
259
+ framealpha=0.5,
260
+ )
261
+
262
+ plt.tight_layout(pad=0)
263
+ fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
264
+ plt.close()
265
+
266
+
267
+ def process_floorplan(image_set, split, source_data_path, save_dir, save_aux_dir, vis_fp=False):
268
+ img, target = image_set
269
+ img = img * torch.tensor(std)[:, None, None] + torch.tensor(mean)[:, None, None] # unnormalize
270
+ graph = tensors_to_graphs_batch([target["graph"]])
271
+ del target["graph"]
272
+
273
+ tgt_this_preds = []
274
+ tgt_this_edges = []
275
+ for _ in range(len(target["points"])):
276
+ tgt_p_d = {}
277
+ tgt_p_d["scores"] = torch.tensor(1.0000, device="cpu")
278
+ tgt_p_d["points"] = target["unnormalized_points"][_]
279
+ tgt_p_d["edges"] = target["edges"][_]
280
+ tgt_p_d["size"] = target["size"]
281
+ if "semantic_left_up" in target:
282
+ tgt_p_d["semantic_left_up"] = target["semantic_left_up"][_]
283
+ tgt_p_d["semantic_right_up"] = target["semantic_right_up"][_]
284
+ tgt_p_d["semantic_right_down"] = target["semantic_right_down"][_]
285
+ tgt_p_d["semantic_left_down"] = target["semantic_left_down"][_]
286
+ tgt_this_preds.append(tgt_p_d)
287
+ for __ in range(4):
288
+ adj = graph[0][tuple(tgt_p_d["points"].tolist())][__]
289
+ if adj != (-1, -1):
290
+ tgt_p_d1 = tgt_p_d
291
+ tgt_p_d2 = {}
292
+ indx = 99999
293
+ for ___, up in enumerate(target["unnormalized_points"].tolist()):
294
+ if abs(up[0] - adj[0]) + abs(up[1] - adj[1]) <= 2:
295
+ indx = ___
296
+ break
297
+ # assert indx != 99999
298
+ if indx == 99999: # No match found
299
+ # Log a warning or skip this iteration
300
+ print(f"Warning: No match found for adj {adj}")
301
+ continue # Skip to the next iteration
302
+ # tgt_p_d2['scores'] = torch.tensor(1.0000, device='cuda:0')
303
+ tgt_p_d2["points"] = target["unnormalized_points"][indx]
304
+ tgt_p_d2["edges"] = target["edges"][indx]
305
+ tgt_p_d2["size"] = target["size"]
306
+ if "semantic_left_up" in target:
307
+ tgt_p_d2["semantic_left_up"] = target["semantic_left_up"][indx]
308
+ tgt_p_d2["semantic_right_up"] = target["semantic_right_up"][indx]
309
+ tgt_p_d2["semantic_right_down"] = target["semantic_right_down"][indx]
310
+ tgt_p_d2["semantic_left_down"] = target["semantic_left_down"][indx]
311
+ tgt_e_l = (tgt_p_d1, tgt_p_d2)
312
+ if not edge_inside((tgt_p_d2, tgt_p_d1), tgt_this_edges):
313
+ tgt_this_edges.append(tgt_e_l)
314
+ tgt = [(tgt_this_preds, [], tgt_this_edges)]
315
+ target_d_rev, target_simple_cycles, target_results = get_cycle_basis_and_semantic((2, 999999, tgt))
316
+
317
+ # convert to coco format
318
+ polys_list = []
319
+ polys_semantic_list = []
320
+ output_json = []
321
+
322
+ image_width, image_height = target["size"][0].item(), target["size"][1].item()
323
+ filename = target["file_name"].split(".")[0]
324
+ img_id = int(target["image_id"])
325
+
326
+ img_dict = {}
327
+ img_dict["file_name"] = str(img_id).zfill(6) + ".png"
328
+ img_dict["id"] = img_id
329
+ img_dict["width"] = image_width
330
+ img_dict["height"] = image_height
331
+ save_dict = prepare_dict()
332
+
333
+ os.makedirs(os.path.join(save_dir, split), exist_ok=True)
334
+ os.makedirs(f"{save_dir}/{split}_jsons/", exist_ok=True)
335
+ json_path = f"{save_dir}/{split}_jsons/{str(img_id).zfill(6)}.json"
336
+
337
+ for instance_id, (poly, poly_cls) in enumerate(zip(target_simple_cycles, target_results)):
338
+ t = [(int(pt[0]), int(pt[1])) for pt in poly]
339
+ class_id = int(poly_cls)
340
+
341
+ polys_list.append(t)
342
+ polys_semantic_list.append(class_id)
343
+
344
+ poly_shapely = Polygon(t)
345
+ area = poly_shapely.area
346
+ coco_seg_poly = []
347
+ polygon = np.array(t)
348
+ poly_sorted = resort_corners(polygon)
349
+
350
+ for p in poly_sorted:
351
+ coco_seg_poly += list(p)
352
+
353
+ if area < 100:
354
+ continue
355
+
356
+ if class_id not in ID2CLASS:
357
+ print(f"Warning: Class ID {class_id} not found in ID2CLASS mapping. Skipping instance.")
358
+ continue
359
+
360
+ # Slightly wider bounding box
361
+ rectangle_shapely = poly_shapely.envelope
362
+ bb_x, bb_y = rectangle_shapely.exterior.xy
363
+ coco_bb = create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2)
364
+
365
+ output_json.append(
366
+ {
367
+ "image_id": img_id,
368
+ "segmentation": [coco_seg_poly],
369
+ "category_id": class_id,
370
+ "id": instance_id,
371
+ "area": area,
372
+ "bbox": coco_bb,
373
+ "iscrowd": 0,
374
+ }
375
+ )
376
+
377
+ if vis_fp:
378
+ visualize_room_polygons(
379
+ polys_list,
380
+ polys_semantic_list,
381
+ image_size=image_width,
382
+ save_path=os.path.join(save_aux_dir, str(img_id).zfill(6) + ".png"),
383
+ )
384
+ room_map = plot_density_map(
385
+ img,
386
+ image_width,
387
+ polys_list,
388
+ polys_semantic_list,
389
+ plot_text=False,
390
+ )
391
+ cv2.imwrite(os.path.join(save_aux_dir, str(img_id).zfill(6) + "_density_map.png"), room_map)
392
+
393
+ print(f"Processed image {img_id} with {len(output_json)} instances.")
394
+ # print(f"Class: {target_results}")
395
+ # min_class_id = min(target_results)
396
+ # max_class_id = max(target_results)
397
+ # if max_class_id == 12:
398
+ # breakpoint()
399
+ # print(f"Min class ID: {min_class_id}, Max class ID: {max_class_id}")
400
+ save_dict["images"].append(img_dict)
401
+ save_dict["annotations"] += output_json
402
+ with open(json_path, "w") as json_file:
403
+ # Convert all numpy and torch types to native Python types for JSON serialization
404
+ def convert(o):
405
+ if isinstance(o, (np.integer, np.int32, np.int64)):
406
+ return int(o)
407
+ if isinstance(o, (np.floating, np.float32, np.float64)):
408
+ return float(o)
409
+ if isinstance(o, (np.ndarray,)):
410
+ return o.tolist()
411
+ if isinstance(o, torch.Tensor):
412
+ return o.item() if o.numel() == 1 else o.tolist()
413
+ return str(o)
414
+
415
+ json.dump(save_dict, json_file, default=convert)
416
+
417
+ # rename image file
418
+ shutil.copy(
419
+ os.path.join(source_data_path, split, filename + ".png"),
420
+ os.path.join(save_dir, split, str(img_id).zfill(6) + ".png"),
421
+ )
422
+
423
+ # Write mapping from source file name to target file name (safe for parallel)
424
+ mapping_line = f"{filename} {str(img_id).zfill(6)}\n"
425
+ # Each process writes to its own temp file
426
+ pid = os.getpid()
427
+ os.makedirs(os.path.join(save_dir, f"{split}_logs"), exist_ok=True)
428
+ mapping_file = os.path.join(save_dir, f"{split}_logs", f"{split}_file_mapping_{pid}.txt")
429
+ with open(mapping_file, "a") as f:
430
+ f.write(mapping_line)
431
+
432
+
433
+ if __name__ == "__main__":
434
+ args = get_args_parser().parse_args()
435
+ torch.set_printoptions(threshold=np.inf, linewidth=999999)
436
+ np.set_printoptions(threshold=np.inf, linewidth=999999)
437
+ gc.collect()
438
+ torch.cuda.empty_cache()
439
+
440
+ def wrapper(scene_id):
441
+ try:
442
+ image_set = dataset[scene_id]
443
+ except Exception as e:
444
+ print(f"Error processing scene {scene_id}: {e}. Skipping...")
445
+ return
446
+ process_floorplan(image_set, split, args.dataset_path, args.output_dir, save_aux_dir, vis_fp=scene_id < 100)
447
+
448
+ def worker_init(dataset_obj):
449
+ # Store dataset as global to avoid pickling issues
450
+ global dataset
451
+ dataset = dataset_obj
452
+
453
+ splits = ["train", "val", "test"]
454
+ for split in splits:
455
+ dataset = MyDataset(
456
+ args.dataset_path + f"/{split}",
457
+ args.dataset_path + "/annot_json" + f"/instances_{split}.json",
458
+ extract_roi=False,
459
+ )
460
+
461
+ save_aux_dir = os.path.join(args.output_dir, f"{split}_aux")
462
+ os.makedirs(save_aux_dir, exist_ok=True)
463
+
464
+ # for i, image_set in enumerate(tqdm(dataset)):
465
+ # save_aux_dir = os.path.join(args.output_dir, f"{split}_aux")
466
+ # os.makedirs(save_aux_dir, exist_ok=True)
467
+ # process_floorplan(image_set, split, args.dataset_path, args.output_dir, save_aux_dir, vis_fp=i < 100)
468
+
469
+ num_processes = 16
470
+ with Pool(num_processes, initializer=worker_init, initargs=(dataset,)) as p:
471
+ indices = range(len(dataset))
472
+ list(tqdm(p.imap(wrapper, indices), total=len(dataset)))
data_preprocess/raster2graph/dataset.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
+ from collections import defaultdict
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.multiprocessing
9
+ import torch.utils.data
10
+ import torchvision.transforms.functional as F
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+ from util.data_utils import l1_dist
14
+ from util.graph_utils import graph_to_tensor
15
+ from util.image_id_dict import d
16
+ from util.mean_std import mean, std
17
+ from util.semantics_dict import semantics_dict
18
+
19
+ torch.multiprocessing.set_sharing_strategy("file_system")
20
+
21
+
22
+ class MyDataset(Dataset):
23
+ def __init__(self, img_path, annot_path, extract_roi, image_size=512):
24
+ self.img_path = img_path
25
+ self.quadtree_path = "/".join(img_path.split("/")[:-1]) + "/annot_npy"
26
+ self.mode = img_path.split("/")[-1]
27
+ self.image_size = image_size
28
+
29
+ # load annotation
30
+ with open(annot_path, "r") as f:
31
+ dataset = json.load(f)
32
+ # images
33
+ self.imgs = {}
34
+ for img in dataset["images"]:
35
+ self.imgs[img["id"]] = img
36
+ self.imgToAnns = defaultdict(list)
37
+ for ann in dataset["annotations"]:
38
+ self.imgToAnns[ann["image_id"]].append(ann)
39
+ self.ids = list(sorted(self.imgs.keys()))
40
+ if "0c-10-c468a57377ff8ef63d3b26a6d1fa-0002" in self.ids:
41
+ self.ids.remove("0c-10-c468a57377ff8ef63d3b26a6d1fa-0002")
42
+ if "0c-10-8486f08035ba152d5244ac54099c-0001" in self.ids:
43
+ self.ids.remove("0c-10-8486f08035ba152d5244ac54099c-0001")
44
+
45
+ def __getitem__(self, index):
46
+ img_id = self.ids[index]
47
+ img_file_name = self.imgs[img_id]["file_name"].replace(".jpg", ".png")
48
+ img = Image.open(os.path.join(self.img_path, img_file_name)).convert("RGB")
49
+ image_scale = self.image_size / img.size[0]
50
+ if img.size[0] != self.image_size or img.size[1] != self.image_size:
51
+ img = img.resize((self.image_size, self.image_size), Image.BILINEAR)
52
+
53
+ if 1:
54
+ # get structure annotations
55
+ anns = self.imgToAnns[img_id]
56
+ new_anns = []
57
+ for ann in anns:
58
+ new_ann = copy.deepcopy(ann)
59
+ new_ann["point"] = [int(ann["point"][0] * image_scale), int(ann["point"][1] * image_scale)]
60
+ new_anns.append(new_ann)
61
+ target = {"image_id": img_id, "annotations": new_anns}
62
+ orig_quadtree = np.load(
63
+ os.path.join(self.quadtree_path, img_file_name[:-4] + ".npy"), allow_pickle=True
64
+ ).item()["quatree"][0]
65
+ quadtree = {}
66
+ for k, v in orig_quadtree.items():
67
+ new_k = k
68
+ new_v = []
69
+ for pos in v:
70
+ new_pos = (int(pos[0] * image_scale), int(pos[1] * image_scale))
71
+ new_v.append(new_pos)
72
+ quadtree[new_k] = new_v
73
+
74
+ orig_graph = np.load(
75
+ os.path.join(self.quadtree_path, img_file_name[:-4] + ".npy"), allow_pickle=True
76
+ ).item()
77
+ del orig_graph["quatree"]
78
+ new_graph = {}
79
+ for k, v in orig_graph.items():
80
+ new_k = (int(k[0] * image_scale), int(k[1] * image_scale))
81
+ new_v = []
82
+ for adj in v:
83
+ if adj == (-1, -1):
84
+ new_v.append((-1, -1))
85
+ else:
86
+ new_v.append((int(adj[0] * image_scale), int(adj[1] * image_scale)))
87
+ new_graph[new_k] = new_v
88
+
89
+ target_layers = []
90
+ for layer, layer_points in quadtree.items():
91
+ target_layer = []
92
+ for layer_point in layer_points:
93
+ for target_i in target["annotations"]:
94
+ if l1_dist(target_i["point"], list(layer_point)) <= 2:
95
+ target_layer.append(target_i)
96
+ break
97
+ target_layers.extend(target_layer)
98
+ layer_indices = []
99
+ count = 0
100
+ for k, v in quadtree.items():
101
+ if k == 0:
102
+ layer_indices.append(0)
103
+ else:
104
+ layer_indices.append(count)
105
+ count += len(v)
106
+
107
+ image_id = torch.tensor([d[img_id]])
108
+
109
+ points = [obj["point"] for obj in target_layers]
110
+ points = torch.as_tensor(points, dtype=torch.int64).reshape(-1, 2)
111
+ edges = [obj["edge_code"] for obj in target_layers]
112
+ edges = torch.tensor(edges, dtype=torch.int64)
113
+
114
+ # get semantic annotations
115
+ semantic_left_up = [semantics_dict[obj["semantic"][0]] for obj in target_layers]
116
+ semantic_right_up = [semantics_dict[obj["semantic"][1]] for obj in target_layers]
117
+ semantic_right_down = [semantics_dict[obj["semantic"][2]] for obj in target_layers]
118
+ semantic_left_down = [semantics_dict[obj["semantic"][3]] for obj in target_layers]
119
+ semantic_left_up = torch.tensor(semantic_left_up, dtype=torch.int64)
120
+ semantic_right_up = torch.tensor(semantic_right_up, dtype=torch.int64)
121
+ semantic_right_down = torch.tensor(semantic_right_down, dtype=torch.int64)
122
+ semantic_left_down = torch.tensor(semantic_left_down, dtype=torch.int64)
123
+
124
+ # annotations
125
+ target = {}
126
+ target["edges"] = edges
127
+ target["file_name"] = img_file_name
128
+ target["image_id"] = image_id
129
+ target["size"] = torch.as_tensor([img.size[1], img.size[0]])
130
+
131
+ target["semantic_left_up"] = semantic_left_up
132
+ target["semantic_right_up"] = semantic_right_up
133
+ target["semantic_right_down"] = semantic_right_down
134
+ target["semantic_left_down"] = semantic_left_down
135
+
136
+ # get image
137
+ img = F.to_tensor(img)
138
+ img = F.normalize(img, mean=mean, std=std)
139
+ target["unnormalized_points"] = points
140
+ # normalize
141
+ points = points / torch.tensor([img.shape[2], img.shape[1]], dtype=torch.float32)
142
+ target["points"] = points
143
+ target["layer_indices"] = torch.tensor(layer_indices)
144
+
145
+ target["graph"] = graph_to_tensor(new_graph)
146
+
147
+ return img, target
148
+
149
+ def __len__(self):
150
+ return len(self.ids)
151
+
152
+
153
+ class MyDataset2(Dataset):
154
+ def __init__(self, img_path, annot_path, extract_roi, disable_sem_info=False):
155
+ self.disable_sem_info = disable_sem_info
156
+ self.img_path = img_path
157
+ self.quadtree_path = "/".join(img_path.split("/")[:-1]) + "/annotations_npy/" + img_path.split("/")[-1]
158
+ self.edgecode_path = "/".join(img_path.split("/")[:-1]) + "/annotations_edge/" + img_path.split("/")[-1]
159
+ self.mode = img_path.split("/")[-1]
160
+
161
+ available_ids = {int(x.replace(".npy", "")) for x in os.listdir(self.quadtree_path)}
162
+
163
+ # load annotation
164
+ with open(annot_path, "r") as f:
165
+ dataset = json.load(f)
166
+ # images
167
+ self.imgs = {}
168
+ for img in dataset["images"]:
169
+ if img["id"] not in available_ids:
170
+ continue
171
+ self.imgs[img["id"]] = img
172
+ self.imgToAnns = defaultdict(list)
173
+ for ann in dataset["annotations"]:
174
+ if ann["image_id"] not in available_ids:
175
+ continue
176
+ self.imgToAnns[ann["image_id"]].append(ann)
177
+ self.ids = list(sorted(self.imgs.keys()))
178
+
179
+ def __getitem__(self, index):
180
+ img_id = self.ids[index]
181
+ img_file_name = self.imgs[int(img_id)]["file_name"]
182
+ img = Image.open(os.path.join(self.img_path, img_file_name)).convert("RGB")
183
+
184
+ if 1:
185
+ # get structure annotations
186
+ # anns = self.imgToAnns[int(img_id)]
187
+
188
+ data = np.load(os.path.join(self.quadtree_path, img_file_name[:-4] + ".npy"), allow_pickle=True).item()
189
+ orig_quadtree = data["quadtree"]
190
+ orig_graph = data["graph"]
191
+ image_points = data["points"]
192
+
193
+ new_anns = []
194
+ for pt in image_points:
195
+ new_ann = {
196
+ "point": [int(pt[0]), int(pt[1])],
197
+ }
198
+ new_anns.append(new_ann)
199
+ target = {"image_id": img_id, "annotations": new_anns}
200
+
201
+ quadtree = {}
202
+ for k, v in orig_quadtree.items():
203
+ new_k = k
204
+ new_v = []
205
+ for pos in v:
206
+ new_pos = (int(pos[0]), int(pos[1]))
207
+ new_v.append(new_pos)
208
+ quadtree[new_k] = new_v
209
+
210
+ new_graph = {}
211
+ for k, v in orig_graph.items():
212
+ new_k = (int(k[0]), int(k[1]))
213
+ new_v = []
214
+ for adj in v:
215
+ if adj == (-1, -1):
216
+ new_v.append((-1, -1))
217
+ else:
218
+ new_v.append((int(adj[0]), int(adj[1])))
219
+ new_graph[new_k] = new_v
220
+
221
+ target_layers = []
222
+ for layer, layer_points in quadtree.items():
223
+ target_layer = []
224
+ for layer_point in layer_points:
225
+ for target_i in target["annotations"]:
226
+ if l1_dist(target_i["point"], list(layer_point)) <= 2:
227
+ target_layer.append(target_i)
228
+ break
229
+ target_layers.extend(target_layer)
230
+ layer_indices = []
231
+ count = 0
232
+ for k, v in quadtree.items():
233
+ if k == 0:
234
+ layer_indices.append(0)
235
+ else:
236
+ layer_indices.append(count)
237
+ count += len(v)
238
+
239
+ image_id = torch.tensor([int(img_id)])
240
+
241
+ points = [obj["point"] for obj in target_layers]
242
+ with open(os.path.join(self.edgecode_path, img_file_name[:-4] + ".json"), "r") as f:
243
+ edge2code = json.load(f)
244
+ edge2code = {
245
+ tuple(map(lambda x: int(float(x)), key.strip("()").split(", "))): value
246
+ for key, value in edge2code.items()
247
+ }
248
+
249
+ edges = [edge2code[(int(pt[0]), int(pt[1]))] for pt in points]
250
+ points = torch.as_tensor(points, dtype=torch.int64).reshape(-1, 2)
251
+ edges = torch.tensor(edges, dtype=torch.int64)
252
+
253
+ # annotations
254
+ target = {}
255
+ target["edges"] = edges
256
+ target["image_id"] = image_id
257
+ target["file_name"] = img_file_name
258
+ target["size"] = torch.as_tensor([img.size[1], img.size[0]])
259
+
260
+ # get semantic annotations
261
+ if not self.disable_sem_info:
262
+ semantic_left_up = [semantics_dict[obj["semantic"][0]] for obj in target_layers]
263
+ semantic_right_up = [semantics_dict[obj["semantic"][1]] for obj in target_layers]
264
+ semantic_right_down = [semantics_dict[obj["semantic"][2]] for obj in target_layers]
265
+ semantic_left_down = [semantics_dict[obj["semantic"][3]] for obj in target_layers]
266
+ semantic_left_up = torch.tensor(semantic_left_up, dtype=torch.int64)
267
+ semantic_right_up = torch.tensor(semantic_right_up, dtype=torch.int64)
268
+ semantic_right_down = torch.tensor(semantic_right_down, dtype=torch.int64)
269
+ semantic_left_down = torch.tensor(semantic_left_down, dtype=torch.int64)
270
+
271
+ target["semantic_left_up"] = semantic_left_up
272
+ target["semantic_right_up"] = semantic_right_up
273
+ target["semantic_right_down"] = semantic_right_down
274
+ target["semantic_left_down"] = semantic_left_down
275
+
276
+ # get image
277
+ img = F.to_tensor(img)
278
+ img = F.normalize(img, mean=mean, std=std)
279
+ target["unnormalized_points"] = points
280
+ # normalize
281
+ points = points / torch.tensor([img.shape[2], img.shape[1]], dtype=torch.float32)
282
+ target["points"] = points
283
+ target["layer_indices"] = torch.tensor(layer_indices)
284
+
285
+ # padding (-1,-1) if not enough 4 neighbors
286
+ for pt, neighbors in new_graph.items():
287
+ if len(neighbors) < 4:
288
+ new_graph[pt].extend([(-1, -1)] * (4 - len(neighbors)))
289
+ elif len(neighbors) > 4:
290
+ new_graph[pt] = neighbors[:4]
291
+ target["graph"] = graph_to_tensor(new_graph)
292
+
293
+ return img, target
294
+
295
+ def __len__(self):
296
+ return len(self.ids)
data_preprocess/raster2graph/image_process.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+
9
+ parser = argparse.ArgumentParser("Preprocess LIFULL HOMES DATA (HIGH RESOLUTION) Dataset")
10
+ parser.add_argument("--data_root", type=str, default=r"R2G_hr_dataset/", help="path to the root folder of the dataset")
11
+ args = parser.parse_args()
12
+
13
+ SIZE = 512
14
+ MARGIN = 64
15
+ np.set_printoptions(threshold=np.inf, linewidth=999999)
16
+
17
+ # original_images_path = r'E:/LIFULL HOMES DATA (HIGH RESOLUTION)/photo-rent-madori-full-00'
18
+ original_images_path = args.data_root
19
+
20
+ with open(f"{args.data_root}/annot_json/instances_train.json", mode="r") as f_train:
21
+ train_jpgs = [_["file_name"] for _ in json.load(f_train)["images"]]
22
+ with open(f"{args.data_root}/annot_json/instances_val.json", mode="r") as f_val:
23
+ val_jpgs = [_["file_name"] for _ in json.load(f_val)["images"]]
24
+ with open(f"{args.data_root}/instances_test.json", mode="r") as f_test:
25
+ test_jpgs = [_["file_name"] for _ in json.load(f_test)["images"]]
26
+ jpgs = {"train": train_jpgs, "val": val_jpgs, "test": test_jpgs}
27
+
28
+ start_idx = 0
29
+ for mode in ["train", "val", "test"]:
30
+ output_dir = "./" + mode
31
+ os.makedirs(output_dir, exist_ok=True)
32
+ for fnames in [jpgs[mode]]:
33
+ for i in tqdm(range(len(fnames))):
34
+ fn = fnames[i].replace(".jpg", "")
35
+ if os.path.exists(os.path.join(f"{args.data_root}/annot_npy", fn + ".npy")) and os.path.exists(
36
+ os.path.join(f"{args.data_root}/original_vector_boundary", fn + ".npy")
37
+ ):
38
+ img_original = Image.open(os.path.join(original_images_path, fn.replace("-", "/") + ".jpg"))
39
+ boundary_path = os.path.join(f"{args.data_root}/original_vector_boundary", fn + ".npy")
40
+ boundary = np.load(boundary_path, allow_pickle=True).item()
41
+ x_min = boundary["x_min"]
42
+ x_max = boundary["x_max"]
43
+ y_min = boundary["y_min"]
44
+ y_max = boundary["y_max"]
45
+ width = x_max - x_min
46
+ mid_width = (x_max + x_min) / 2
47
+ height = y_max - y_min
48
+ mid_height = (y_max + y_min) / 2
49
+ if width > height:
50
+ scale = (SIZE - 2 * MARGIN) / width
51
+ else:
52
+ scale = (SIZE - 2 * MARGIN) / height
53
+ # print(x_min, y_min, x_max, y_max, width, height, scale)
54
+
55
+ original_width, original_height = img_original.size
56
+ new_width = int(original_width * scale)
57
+ new_height = int(original_height * scale)
58
+ scaled_image = img_original.resize((new_width, new_height), Image.Resampling.LANCZOS)
59
+ canvas = Image.new("RGB", (512, 512), (255, 255, 255))
60
+ # print(new_width, new_height)
61
+ x_topleft_offset = int(512 / 2 - mid_width * scale)
62
+ y_topleft_offset = int(512 / 2 - mid_height * scale)
63
+ canvas.paste(scaled_image, (x_topleft_offset, y_topleft_offset))
64
+
65
+ canvas.save(os.path.join(output_dir, fn + ".png"))
66
+
67
+ start_idx += 1
data_preprocess/raster2graph/util/data_utils.py ADDED
@@ -0,0 +1,966 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import random
3
+
4
+ import torch
5
+ from util.edges_utils import get_edges_alldirections_rev
6
+ from util.math_utils import clip
7
+ from util.mean_std import mean, std
8
+
9
+
10
+ def data_to_cuda(samples, targets):
11
+ return samples.to(torch.device("cuda")), [
12
+ {k: v if isinstance(v, str) else v.to(torch.device("cuda")) for k, v in t.items()} for t in targets
13
+ ]
14
+
15
+
16
+ def get_random_layer_targets(targets, gt_layer):
17
+ random_targets = []
18
+ for batch_i, target_i in enumerate(targets):
19
+ random_layer_targets_i = copy.deepcopy(target_i)
20
+ if gt_layer[batch_i] != len(random_layer_targets_i["layer_indices"]) - 1:
21
+ start = random_layer_targets_i["layer_indices"][gt_layer[batch_i]].item()
22
+ end = random_layer_targets_i["layer_indices"][gt_layer[batch_i] + 1].item()
23
+ else:
24
+ start = random_layer_targets_i["layer_indices"][gt_layer[batch_i]].item()
25
+ end = len(random_layer_targets_i["points"])
26
+ random_points_i = random_layer_targets_i["points"][start:end, :]
27
+ random_edges_i = random_layer_targets_i["edges"][start:end]
28
+ random_unnormalized_points_i = random_layer_targets_i["unnormalized_points"][start:end, :]
29
+ random_layer_targets_i["points"] = random_points_i
30
+ random_layer_targets_i["edges"] = random_edges_i
31
+ random_layer_targets_i["unnormalized_points"] = random_unnormalized_points_i
32
+ del random_layer_targets_i["layer_indices"]
33
+ random_targets.append(random_layer_targets_i)
34
+ return random_targets
35
+
36
+
37
+ def random_layers(targets):
38
+ return [random.randint(0, len(targets[i]["layer_indices"]) - 1) for i in range(len(targets))]
39
+
40
+
41
+ def get_given_layers_random_region(targets, graphs):
42
+ random_regions = []
43
+ for bs_i in range(len(targets)):
44
+ # target
45
+ targets_i = targets[bs_i]
46
+ graphs_i = graphs[bs_i]
47
+ # level 0: start
48
+ start_i = tuple(targets_i["unnormalized_points"][0].tolist())
49
+
50
+ # sampled prob: for a neighborhood, each node is sampled by this probability
51
+ # sampled_prob = 0.0001
52
+ # sampled_prob = random.random()
53
+ sampled_prob = 0.5
54
+ # sampled_prob = 1
55
+
56
+ # sampled nodes
57
+ sampled_points = {}
58
+ for point_tensor in targets_i["unnormalized_points"]:
59
+ pos = tuple(point_tensor.tolist())
60
+ sampled_points[pos] = 0
61
+ # edges of sampled nodes
62
+ sampled_edges = []
63
+
64
+ # nodes number of subgraph
65
+ # sampled_amount = random.randint(0, len(sampled_points) + 2)
66
+ # if sampled_amount in [len(sampled_points) + 1]:
67
+ # sampled_amount = 0
68
+ # if sampled_amount in [len(sampled_points) + 2]:
69
+ # sampled_amount = len(sampled_points)
70
+ sampled_amount = random.randint(0, len(sampled_points)) # TODO: 1~len(sampled_points)
71
+
72
+ # Note that when sampled_prob = 1, the number of sampled nodes must be in 'layer_indices' or be the total number of points to ensure that the entire layers is sampled.
73
+ # equal to BFS
74
+ if sampled_prob == 1:
75
+ l = targets_i["layer_indices"].tolist()
76
+ l.append(len(sampled_points))
77
+ l.append(0)
78
+ l.append(len(sampled_points))
79
+ sampled_amount = l[random.randint(0, len(l) - 1)]
80
+
81
+ # start sampling
82
+ if sampled_amount == 0:
83
+ random_regions.append((sampled_points, sampled_edges))
84
+ continue
85
+ sampled_points[start_i] = 1
86
+ if sampled_amount == 1:
87
+ random_regions.append((sampled_points, sampled_edges))
88
+ continue
89
+
90
+ max_iterations = max(1000, 10 * sampled_amount) # Ensure at least 1000 iterations
91
+ iteration_count = 0
92
+ while sum(sampled_points.values()) < sampled_amount:
93
+ iteration_count += 1
94
+ if iteration_count > max_iterations:
95
+ print("Reached maximum iterations, breaking to avoid infinite loop.")
96
+ break
97
+ all_sampled_points = set([k for k, v in sampled_points.items() if v == 1])
98
+ all_sampled_points_adjs = set()
99
+ for sampled_point in all_sampled_points:
100
+ adj = set([(int(x[0]), int(x[1])) for x in graphs_i[sampled_point]])
101
+ all_sampled_points_adjs = all_sampled_points_adjs.union(adj)
102
+
103
+ if (-1, -1) in all_sampled_points_adjs:
104
+ all_sampled_points_adjs.remove((-1, -1))
105
+ all_sampled_points_adjs = list(all_sampled_points_adjs.difference(all_sampled_points))
106
+
107
+ if not all_sampled_points_adjs: # If no more adjacent points to sample, break
108
+ print("No more adjacent points to sample, breaking the loop.")
109
+ break
110
+
111
+ # shuffle the last layer to let it uniform (no bias of sample order)
112
+ random.shuffle(all_sampled_points_adjs)
113
+ # determine whether to sample nodes in each neighborhood based on probability
114
+ for all_sampled_points_adj_index, all_sampled_points_adj in enumerate(all_sampled_points_adjs):
115
+ all_sampled_points = set([k for k, v in sampled_points.items() if v == 1])
116
+ if sum(sampled_points.values()) == sampled_amount:
117
+ break
118
+ else:
119
+ if 1:
120
+ if random.random() < sampled_prob:
121
+ sampled_points[all_sampled_points_adj] = 1
122
+ # sample edges
123
+ all_pos1s = graphs_i[all_sampled_points_adj]
124
+ pos2 = all_sampled_points_adj
125
+ for pos1 in all_pos1s:
126
+ if pos1 in all_sampled_points:
127
+ sampled_edges.append((pos1, pos2))
128
+ else:
129
+ sampled_points[all_sampled_points_adj] = 0
130
+ random_regions.append((sampled_points, sampled_edges))
131
+ return random_regions
132
+
133
+
134
+ def get_random_region_targets(given_layers, graphs, targets):
135
+ random_region_targets = []
136
+ for bs_i in range(len(targets)):
137
+ random_region_target = {}
138
+ targets_i = targets[bs_i]
139
+ graphs_i = graphs[bs_i]
140
+ given_layers_i = given_layers[bs_i]
141
+ sampled_points_i, sampled_edges_i = given_layers_i
142
+
143
+ if sum(sampled_points_i.values()) == 0:
144
+ random_region_target["edges"] = targets_i["edges"][:1]
145
+
146
+ if "semantic_left_up" in targets_i:
147
+ random_region_target["semantic_left_up"] = targets_i["semantic_left_up"][:1]
148
+ random_region_target["semantic_right_up"] = targets_i["semantic_right_up"][:1]
149
+ random_region_target["semantic_right_down"] = targets_i["semantic_right_down"][:1]
150
+ random_region_target["semantic_left_down"] = targets_i["semantic_left_down"][:1]
151
+
152
+ random_region_target["image_id"] = targets_i["image_id"]
153
+ random_region_target["size"] = targets_i["size"]
154
+ random_region_target["unnormalized_points"] = targets_i["unnormalized_points"][:1]
155
+ random_region_target["points"] = targets_i["points"][:1]
156
+ random_region_target["last_edges"] = torch.zeros(
157
+ (1,), dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
158
+ )
159
+ random_region_target["this_edges"] = torch.zeros(
160
+ (1,), dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
161
+ )
162
+ random_region_targets.append(random_region_target)
163
+ elif 1 <= sum(sampled_points_i.values()) <= len(sampled_points_i) - 1:
164
+ sampled_points_i_given = set([k for k, v in sampled_points_i.items() if v == 1])
165
+ unnormalized_points = []
166
+ for point, sampled_or_not in sampled_points_i.items():
167
+ if sampled_or_not == 0:
168
+ adjs = graphs_i[point]
169
+ for adj in adjs:
170
+ if adj in sampled_points_i_given:
171
+ unnormalized_points.append(point)
172
+ break
173
+
174
+ if len(unnormalized_points) == 0:
175
+ random_region_target["edges"] = targets_i["edges"][:1]
176
+
177
+ if "semantic_left_up" in targets_i:
178
+ random_region_target["semantic_left_up"] = targets_i["semantic_left_up"][:1]
179
+ random_region_target["semantic_right_up"] = targets_i["semantic_right_up"][:1]
180
+ random_region_target["semantic_right_down"] = targets_i["semantic_right_down"][:1]
181
+ random_region_target["semantic_left_down"] = targets_i["semantic_left_down"][:1]
182
+
183
+ random_region_target["image_id"] = targets_i["image_id"]
184
+ random_region_target["size"] = targets_i["size"]
185
+ random_region_target["unnormalized_points"] = targets_i["unnormalized_points"][:1]
186
+ random_region_target["points"] = targets_i["points"][:1]
187
+ random_region_target["last_edges"] = torch.zeros(
188
+ (1,), dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
189
+ )
190
+ random_region_target["this_edges"] = torch.zeros(
191
+ (1,), dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
192
+ )
193
+ random_region_targets.append(random_region_target)
194
+ continue
195
+
196
+ indices_for_semantic = []
197
+ for unnormalized_point in unnormalized_points:
198
+ for ind, every_point in enumerate(targets_i["unnormalized_points"]):
199
+ every_point = tuple(every_point.tolist())
200
+ if (
201
+ abs(every_point[0] - unnormalized_point[0]) <= 2
202
+ and abs(every_point[1] - unnormalized_point[1]) <= 2
203
+ ):
204
+ indices_for_semantic.append(ind)
205
+ # assert len(unnormalized_points) == len(indices_for_semantic)
206
+ semantic_left_up = []
207
+ semantic_right_up = []
208
+ semantic_right_down = []
209
+ semantic_left_down = []
210
+
211
+ if "semantic_left_up" in targets_i:
212
+ for ind in indices_for_semantic:
213
+ semantic_left_up.append(targets_i["semantic_left_up"][ind].item())
214
+ semantic_right_up.append(targets_i["semantic_right_up"][ind].item())
215
+ semantic_right_down.append(targets_i["semantic_right_down"][ind].item())
216
+ semantic_left_down.append(targets_i["semantic_left_down"][ind].item())
217
+
218
+ edges = []
219
+ for unnormalized_point in unnormalized_points:
220
+ edge = ""
221
+ adjs = graphs_i[unnormalized_point]
222
+ for adj in adjs:
223
+ if adj != (-1, -1):
224
+ edge += "1"
225
+ else:
226
+ edge += "0"
227
+ edge = get_edges_alldirections_rev(edge)
228
+ edges.append(edge)
229
+ last_edges = []
230
+ for unnormalized_point in unnormalized_points:
231
+ last_edge = ""
232
+ adjs = graphs_i[unnormalized_point]
233
+ for adj in adjs:
234
+ if adj in sampled_points_i_given:
235
+ last_edge += "1"
236
+ else:
237
+ last_edge += "0"
238
+ last_edge = get_edges_alldirections_rev(last_edge)
239
+ last_edges.append(last_edge)
240
+ this_edges = []
241
+ for unnormalized_point in unnormalized_points:
242
+ this_edge = ""
243
+ adjs = graphs_i[unnormalized_point]
244
+ for adj in adjs:
245
+ if adj in unnormalized_points:
246
+ this_edge += "1"
247
+ else:
248
+ this_edge += "0"
249
+ this_edge = get_edges_alldirections_rev(this_edge)
250
+ this_edges.append(this_edge)
251
+
252
+ random_region_target["edges"] = torch.tensor(
253
+ edges, dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
254
+ )
255
+
256
+ if "semantic_left_up" in targets_i:
257
+ random_region_target["semantic_left_up"] = torch.tensor(
258
+ semantic_left_up,
259
+ dtype=targets_i["semantic_left_up"].dtype,
260
+ device=targets_i["semantic_left_up"].device,
261
+ )
262
+ random_region_target["semantic_right_up"] = torch.tensor(
263
+ semantic_right_up,
264
+ dtype=targets_i["semantic_right_up"].dtype,
265
+ device=targets_i["semantic_right_up"].device,
266
+ )
267
+ random_region_target["semantic_right_down"] = torch.tensor(
268
+ semantic_right_down,
269
+ dtype=targets_i["semantic_right_down"].dtype,
270
+ device=targets_i["semantic_right_down"].device,
271
+ )
272
+ random_region_target["semantic_left_down"] = torch.tensor(
273
+ semantic_left_down,
274
+ dtype=targets_i["semantic_left_down"].dtype,
275
+ device=targets_i["semantic_left_down"].device,
276
+ )
277
+
278
+ random_region_target["image_id"] = targets_i["image_id"]
279
+ random_region_target["size"] = targets_i["size"]
280
+
281
+ # NEW
282
+ # if len(unnormalized_points) == 0:
283
+ # print("Warning: unnormalized_points is empty. Initializing to default value.")
284
+ # unnormalized_points = torch.zeros((1, 2), dtype=targets_i['unnormalized_points'].dtype,
285
+ # device=targets_i['unnormalized_points'].device)
286
+ # else:
287
+ # random_region_target['unnormalized_points'] = torch.tensor(unnormalized_points,
288
+ # dtype=targets_i['unnormalized_points'].dtype,
289
+ # device=targets_i['unnormalized_points'].device)
290
+
291
+ random_region_target["unnormalized_points"] = torch.tensor(
292
+ unnormalized_points,
293
+ dtype=targets_i["unnormalized_points"].dtype,
294
+ device=targets_i["unnormalized_points"].device,
295
+ )
296
+ random_region_target["points"] = (
297
+ torch.tensor(unnormalized_points, dtype=targets_i["points"].dtype, device=targets_i["points"].device)
298
+ / targets_i["size"]
299
+ )
300
+ random_region_target["last_edges"] = torch.tensor(
301
+ last_edges, dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
302
+ )
303
+ random_region_target["this_edges"] = torch.tensor(
304
+ this_edges, dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
305
+ )
306
+ random_region_targets.append(random_region_target)
307
+ else:
308
+ random_region_target["edges"] = 16 * torch.ones(
309
+ targets_i["edges"][:1].shape, dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
310
+ )
311
+
312
+ if "semantic_left_up" in targets_i:
313
+ random_region_target["semantic_left_up"] = 11 * torch.ones(
314
+ targets_i["semantic_left_up"][:1].shape,
315
+ dtype=targets_i["semantic_left_up"].dtype,
316
+ device=targets_i["semantic_left_up"].device,
317
+ )
318
+ random_region_target["semantic_right_up"] = 11 * torch.ones(
319
+ targets_i["semantic_right_up"][:1].shape,
320
+ dtype=targets_i["semantic_right_up"].dtype,
321
+ device=targets_i["semantic_right_up"].device,
322
+ )
323
+ random_region_target["semantic_right_down"] = 11 * torch.ones(
324
+ targets_i["semantic_right_down"][:1].shape,
325
+ dtype=targets_i["semantic_right_down"].dtype,
326
+ device=targets_i["semantic_right_down"].device,
327
+ )
328
+ random_region_target["semantic_left_down"] = 11 * torch.ones(
329
+ targets_i["semantic_left_down"][:1].shape,
330
+ dtype=targets_i["semantic_left_down"].dtype,
331
+ device=targets_i["semantic_left_down"].device,
332
+ )
333
+
334
+ random_region_target["image_id"] = targets_i["image_id"]
335
+ random_region_target["size"] = targets_i["size"]
336
+ random_region_target["unnormalized_points"] = 505 * torch.ones(
337
+ targets_i["unnormalized_points"][:1].shape,
338
+ dtype=targets_i["unnormalized_points"][:1].dtype,
339
+ device=targets_i["unnormalized_points"][:1].device,
340
+ )
341
+ random_region_target["points"] = (
342
+ 505
343
+ * torch.ones(
344
+ targets_i["unnormalized_points"][:1].shape,
345
+ dtype=targets_i["points"][:1].dtype,
346
+ device=targets_i["points"][:1].device,
347
+ )
348
+ ) / targets_i["size"]
349
+ random_region_target["last_edges"] = 16 * torch.ones(
350
+ (1,), dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
351
+ )
352
+ random_region_target["this_edges"] = 16 * torch.ones(
353
+ (1,), dtype=targets_i["edges"].dtype, device=targets_i["edges"].device
354
+ )
355
+ random_region_targets.append(random_region_target)
356
+
357
+ return random_region_targets
358
+
359
+
360
+ def random_pertubation(sampled_points_i, sampled_edges_i):
361
+ random_pertube_map = {}
362
+ sigma = 2
363
+ pertube_threshold = 5
364
+ for sampled_point in sampled_points_i:
365
+ random_pertube_map[sampled_point] = (
366
+ sampled_point[0] + clip(int(random.gauss(0, sigma)), -1 * pertube_threshold, pertube_threshold),
367
+ sampled_point[1] + clip(int(random.gauss(0, sigma)), -1 * pertube_threshold, pertube_threshold),
368
+ )
369
+ new_sampled_points_i = {}
370
+ new_sampled_edges_i = []
371
+ for sampled_point in sampled_points_i:
372
+ new_sampled_points_i[random_pertube_map[sampled_point]] = sampled_points_i[sampled_point]
373
+ for pos1, pos2 in sampled_edges_i:
374
+ new_sampled_edges_i.append((random_pertube_map[pos1], random_pertube_map[pos2]))
375
+ return new_sampled_points_i, new_sampled_edges_i
376
+
377
+
378
+ def draw_given_layers_on_tensors_random_region(given_layers, tensors, graphs):
379
+ """draw 9*9 yellow squares and width 2 blue lines"""
380
+ tensors_list = []
381
+ unnormalized_list = []
382
+ for i in range(len(given_layers)):
383
+ temp_tensor = tensors[i]
384
+
385
+ temp_tensor_0 = (temp_tensor[0] * std[0] + mean[0]) * 255
386
+ temp_tensor_1 = (temp_tensor[1] * std[1] + mean[1]) * 255
387
+ temp_tensor_2 = (temp_tensor[2] * std[2] + mean[2]) * 255
388
+
389
+ rectangle_radius = 5
390
+
391
+ # end sign
392
+ endsign = (505, 505)
393
+ valid_violet_endsign_up = endsign[1] - rectangle_radius
394
+ valid_violet_endsign_down = endsign[1] + rectangle_radius
395
+ valid_violet_endsign_left = endsign[0] - rectangle_radius
396
+ valid_violet_endsign_right = endsign[0] + rectangle_radius
397
+ temp_tensor_0[
398
+ valid_violet_endsign_up : valid_violet_endsign_down + 1,
399
+ valid_violet_endsign_left : valid_violet_endsign_right + 1,
400
+ ] = 255
401
+ temp_tensor_1[
402
+ valid_violet_endsign_up : valid_violet_endsign_down + 1,
403
+ valid_violet_endsign_left : valid_violet_endsign_right + 1,
404
+ ] = 0
405
+ temp_tensor_2[
406
+ valid_violet_endsign_up : valid_violet_endsign_down + 1,
407
+ valid_violet_endsign_left : valid_violet_endsign_right + 1,
408
+ ] = 255
409
+
410
+ sampled_points_i, sampled_edges_i = given_layers[i]
411
+ sampled_points_i, sampled_edges_i = random_pertubation(sampled_points_i, sampled_edges_i)
412
+
413
+ given_points = [k for k, v in sampled_points_i.items() if v == 1]
414
+
415
+ for j, pos in enumerate(given_points):
416
+ valid_yellow_pos_up = int(pos[1] - rectangle_radius) if (pos[1] - rectangle_radius) >= 0 else 0
417
+ valid_yellow_pos_down = (
418
+ int(pos[1] + rectangle_radius)
419
+ if (pos[1] + rectangle_radius) < temp_tensor.shape[2]
420
+ else temp_tensor.shape[2] - 1
421
+ )
422
+ valid_yellow_pos_left = int(pos[0] - rectangle_radius) if (pos[0] - rectangle_radius) >= 0 else 0
423
+ valid_yellow_pos_right = (
424
+ int(pos[0] + rectangle_radius)
425
+ if (pos[0] + rectangle_radius) < temp_tensor.shape[1]
426
+ else temp_tensor.shape[1] - 1
427
+ )
428
+
429
+ temp_tensor_0[
430
+ valid_yellow_pos_up : valid_yellow_pos_down + 1, valid_yellow_pos_left : valid_yellow_pos_right + 1
431
+ ] = 255
432
+ temp_tensor_1[
433
+ valid_yellow_pos_up : valid_yellow_pos_down + 1, valid_yellow_pos_left : valid_yellow_pos_right + 1
434
+ ] = 255
435
+ temp_tensor_2[
436
+ valid_yellow_pos_up : valid_yellow_pos_down + 1, valid_yellow_pos_left : valid_yellow_pos_right + 1
437
+ ] = 0
438
+
439
+ # draw blue lines
440
+ line_width = 2
441
+ for edge in sampled_edges_i:
442
+ pos1 = (int(edge[0][0]), int(edge[0][1]))
443
+ pos2 = (int(edge[1][0]), int(edge[1][1]))
444
+ if abs(pos1[0] - pos2[0]) < abs(pos1[1] - pos2[1]):
445
+ if pos1[1] > pos2[1]:
446
+ temp_tensor_0[
447
+ pos2[1] : pos1[1] + 1,
448
+ int((pos1[0] + pos2[0]) / 2) - int(line_width / 2) : int((pos1[0] + pos2[0]) / 2)
449
+ + int(line_width / 2)
450
+ + 1,
451
+ ] = 0
452
+ temp_tensor_1[
453
+ pos2[1] : pos1[1] + 1,
454
+ int((pos1[0] + pos2[0]) / 2) - int(line_width / 2) : int((pos1[0] + pos2[0]) / 2)
455
+ + int(line_width / 2)
456
+ + 1,
457
+ ] = 0
458
+ temp_tensor_2[
459
+ pos2[1] : pos1[1] + 1,
460
+ int((pos1[0] + pos2[0]) / 2) - int(line_width / 2) : int((pos1[0] + pos2[0]) / 2)
461
+ + int(line_width / 2)
462
+ + 1,
463
+ ] = 255
464
+ else:
465
+ temp_tensor_0[
466
+ pos1[1] : pos2[1] + 1,
467
+ int((pos2[0] + pos1[0]) / 2) - int(line_width / 2) : int((pos2[0] + pos1[0]) / 2)
468
+ + int(line_width / 2)
469
+ + 1,
470
+ ] = 0
471
+ temp_tensor_1[
472
+ pos1[1] : pos2[1] + 1,
473
+ int((pos2[0] + pos1[0]) / 2) - int(line_width / 2) : int((pos2[0] + pos1[0]) / 2)
474
+ + int(line_width / 2)
475
+ + 1,
476
+ ] = 0
477
+ temp_tensor_2[
478
+ pos1[1] : pos2[1] + 1,
479
+ int((pos2[0] + pos1[0]) / 2) - int(line_width / 2) : int((pos2[0] + pos1[0]) / 2)
480
+ + int(line_width / 2)
481
+ + 1,
482
+ ] = 255
483
+ else:
484
+ if pos1[0] > pos2[0]:
485
+ temp_tensor_0[
486
+ int((pos1[1] + pos2[1]) / 2) - int(line_width / 2) : int((pos1[1] + pos2[1]) / 2)
487
+ + int(line_width / 2)
488
+ + 1,
489
+ pos2[0] : pos1[0] + 1,
490
+ ] = 0
491
+ temp_tensor_1[
492
+ int((pos1[1] + pos2[1]) / 2) - int(line_width / 2) : int((pos1[1] + pos2[1]) / 2)
493
+ + int(line_width / 2)
494
+ + 1,
495
+ pos2[0] : pos1[0] + 1,
496
+ ] = 0
497
+ temp_tensor_2[
498
+ int((pos1[1] + pos2[1]) / 2) - int(line_width / 2) : int((pos1[1] + pos2[1]) / 2)
499
+ + int(line_width / 2)
500
+ + 1,
501
+ pos2[0] : pos1[0] + 1,
502
+ ] = 255
503
+ else:
504
+ temp_tensor_0[
505
+ int((pos2[1] + pos1[1]) / 2) - int(line_width / 2) : int((pos2[1] + pos1[1]) / 2)
506
+ + int(line_width / 2)
507
+ + 1,
508
+ pos1[0] : pos2[0] + 1,
509
+ ] = 0
510
+ temp_tensor_1[
511
+ int((pos2[1] + pos1[1]) / 2) - int(line_width / 2) : int((pos2[1] + pos1[1]) / 2)
512
+ + int(line_width / 2)
513
+ + 1,
514
+ pos1[0] : pos2[0] + 1,
515
+ ] = 0
516
+ temp_tensor_2[
517
+ int((pos2[1] + pos1[1]) / 2) - int(line_width / 2) : int((pos2[1] + pos1[1]) / 2)
518
+ + int(line_width / 2)
519
+ + 1,
520
+ pos1[0] : pos2[0] + 1,
521
+ ] = 255
522
+
523
+ unnormalized = torch.stack((temp_tensor_0, temp_tensor_1, temp_tensor_2), dim=0)
524
+ unnormalized_list.append(unnormalized)
525
+
526
+ temp_tensor_0_renorm = ((temp_tensor_0 / 255) - mean[0]) / std[0]
527
+ temp_tensor_1_renorm = ((temp_tensor_1 / 255) - mean[1]) / std[1]
528
+ temp_tensor_2_renorm = ((temp_tensor_2 / 255) - mean[2]) / std[2]
529
+
530
+ temp_tensor = torch.stack([temp_tensor_0_renorm, temp_tensor_1_renorm, temp_tensor_2_renorm], dim=0)
531
+
532
+ tensors_list.append(temp_tensor)
533
+
534
+ return torch.stack(tensors_list, dim=0), torch.stack(unnormalized_list, dim=0)
535
+
536
+
537
+ def initialize_tensors(tensors):
538
+ tensors_list = []
539
+ unnormalized_list = []
540
+ for i in range(len(tensors)):
541
+ temp_tensor = tensors[i]
542
+
543
+ temp_tensor_0 = (temp_tensor[0] * std[0] + mean[0]) * 255
544
+ temp_tensor_1 = (temp_tensor[1] * std[1] + mean[1]) * 255
545
+ temp_tensor_2 = (temp_tensor[2] * std[2] + mean[2]) * 255
546
+
547
+ rectangle_radius = 5 # 4+1+4=9
548
+
549
+ # end sign (when predict this, AR iteration terminates)
550
+ endsign = (505, 505)
551
+ valid_violet_endsign_up = endsign[1] - rectangle_radius
552
+ valid_violet_endsign_down = endsign[1] + rectangle_radius
553
+ valid_violet_endsign_left = endsign[0] - rectangle_radius
554
+ valid_violet_endsign_right = endsign[0] + rectangle_radius
555
+ temp_tensor_0[
556
+ valid_violet_endsign_up : valid_violet_endsign_down + 1,
557
+ valid_violet_endsign_left : valid_violet_endsign_right + 1,
558
+ ] = 255
559
+ temp_tensor_1[
560
+ valid_violet_endsign_up : valid_violet_endsign_down + 1,
561
+ valid_violet_endsign_left : valid_violet_endsign_right + 1,
562
+ ] = 0
563
+ temp_tensor_2[
564
+ valid_violet_endsign_up : valid_violet_endsign_down + 1,
565
+ valid_violet_endsign_left : valid_violet_endsign_right + 1,
566
+ ] = 255
567
+
568
+ unnormalized = torch.stack((temp_tensor_0, temp_tensor_1, temp_tensor_2), dim=0)
569
+ unnormalized_list.append(unnormalized)
570
+
571
+ temp_tensor_0_renorm = ((temp_tensor_0 / 255) - mean[0]) / std[0]
572
+ temp_tensor_1_renorm = ((temp_tensor_1 / 255) - mean[1]) / std[1]
573
+ temp_tensor_2_renorm = ((temp_tensor_2 / 255) - mean[2]) / std[2]
574
+
575
+ temp_tensor = torch.stack([temp_tensor_0_renorm, temp_tensor_1_renorm, temp_tensor_2_renorm], dim=0)
576
+
577
+ tensors_list.append(temp_tensor)
578
+
579
+ return torch.stack(tensors_list, dim=0), torch.stack(unnormalized_list, dim=0)
580
+
581
+
582
+ def l1_dist(pos1, pos2):
583
+ return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
584
+
585
+
586
+ def delete_graphs(targets):
587
+ no_graph_targets = []
588
+ for target in targets:
589
+ target_ = copy.deepcopy(target)
590
+ del target_["graph"]
591
+ no_graph_targets.append(target_)
592
+ return no_graph_targets
593
+
594
+
595
+ def delete_graphs_and_unnormpoints(targets):
596
+ no_graph_targets = []
597
+ for target in targets:
598
+ target_ = copy.deepcopy(target)
599
+ del target_["graph"]
600
+ del target_["unnormalized_points"]
601
+ no_graph_targets.append(target_)
602
+ return no_graph_targets
603
+
604
+
605
+ def get_remove_point(this_preds, dist_threshold):
606
+ for point1 in this_preds:
607
+ for point2 in this_preds:
608
+ # if point1 != point2:
609
+ if not (
610
+ (point1["points"].tolist()[0] == point2["points"].tolist()[0])
611
+ and (point1["points"].tolist()[1] == point2["points"].tolist()[1])
612
+ ):
613
+ dist_chebyshev = max(
614
+ abs(point1["points"].tolist()[0] - point2["points"].tolist()[0]),
615
+ abs(point1["points"].tolist()[1] - point2["points"].tolist()[1]),
616
+ )
617
+ if dist_chebyshev <= dist_threshold:
618
+ point1_confidence = point1["scores"].item()
619
+ point2_confidence = point2["scores"].item()
620
+ if point1_confidence < point2_confidence:
621
+ return point1
622
+ elif point2_confidence < point1_confidence:
623
+ return point2
624
+ else:
625
+ return [point1, point2][random.randint(0, 1)]
626
+ return None
627
+
628
+
629
+ def point_inside(point, points_list):
630
+ point1 = tuple(point["points"].tolist())
631
+ for point_i in points_list:
632
+ point1_i = tuple(point_i["points"].tolist())
633
+ if point1 == point1_i:
634
+ return True
635
+ return False
636
+
637
+
638
+ def remove_points(need_to_remove_in_last_edges, this_preds):
639
+ result = []
640
+ for this_pred in this_preds:
641
+ if not point_inside(this_pred, need_to_remove_in_last_edges):
642
+ result.append(this_pred)
643
+ return result
644
+
645
+
646
+ def nms(this_preds):
647
+ if len(this_preds) <= 1:
648
+ return this_preds
649
+ else:
650
+ dist_threshold = 5
651
+ while True:
652
+ remove_point = get_remove_point(this_preds, dist_threshold)
653
+ if remove_point is None:
654
+ break
655
+ else:
656
+ # this_preds.remove(remove_point)
657
+ this_preds = remove_points([remove_point], this_preds)
658
+
659
+ return this_preds
660
+
661
+
662
+ def nms_givenpoints(this_preds, preds):
663
+ if len(this_preds) == 0:
664
+ return this_preds
665
+ else:
666
+ all_given_points = []
667
+ for given_points, given_last_edges, given_this_edges in preds:
668
+ all_given_points.extend(given_points)
669
+ if len(all_given_points) == 0:
670
+ return this_preds
671
+ this_preds_copy = copy.deepcopy(this_preds)
672
+ dist_threshold = 5
673
+ for this_pred in this_preds_copy:
674
+ for given_point in all_given_points:
675
+ this_pred_pos = tuple(this_pred["points"].tolist())
676
+ given_point_pos = tuple(given_point["points"].tolist())
677
+ dist_chebyshev = max(
678
+ abs(this_pred_pos[0] - given_point_pos[0]), abs(this_pred_pos[1] - given_point_pos[1])
679
+ )
680
+ if dist_chebyshev <= dist_threshold:
681
+ this_preds = remove_points([this_pred], this_preds)
682
+ break
683
+ return this_preds
684
+
685
+
686
+ def random_keep(this_preds):
687
+ if len(this_preds) <= 1:
688
+ return this_preds
689
+ else:
690
+ while True:
691
+ random_keep_this_preds = []
692
+ for point in this_preds:
693
+ # is_keep = random.random() < point['scores'].item()
694
+ is_keep = random.random() < 1.01
695
+ # is_keep = random.random() < 0.5
696
+ if is_keep:
697
+ random_keep_this_preds.append(point)
698
+ if len(random_keep_this_preds) > 0:
699
+ return random_keep_this_preds
700
+
701
+
702
+ def is_stop(this_preds):
703
+ if len(this_preds) == 0:
704
+ return 1 # stop
705
+ elif (len(this_preds) >= 1) and (16 in [p["edges"].item() for p in this_preds]):
706
+ return 2 # normally terminate
707
+ else:
708
+ return 0 # not stop
709
+
710
+
711
+ def draw_preds_on_tensors(preds, tensors):
712
+ tensors_list = []
713
+ unnormalized_list = []
714
+
715
+ for i in range(len(tensors)):
716
+ temp_tensor = tensors[i]
717
+
718
+ temp_tensor_0 = (temp_tensor[0] * std[0] + mean[0]) * 255
719
+ temp_tensor_1 = (temp_tensor[1] * std[1] + mean[1]) * 255
720
+ temp_tensor_2 = (temp_tensor[2] * std[2] + mean[2]) * 255
721
+
722
+ rectangle_radius = 5
723
+
724
+ this_preds, last_edges, this_edges = preds[-1]
725
+ for this_pred in this_preds:
726
+ point = tuple([int(_) for _ in this_pred["points"].tolist()])
727
+ up = point[1] - rectangle_radius
728
+ down = point[1] + rectangle_radius
729
+ left = point[0] - rectangle_radius
730
+ right = point[0] + rectangle_radius
731
+ temp_tensor_0[up : down + 1, left : right + 1] = 255
732
+ temp_tensor_1[up : down + 1, left : right + 1] = 255
733
+ temp_tensor_2[up : down + 1, left : right + 1] = 0
734
+ line_width = 2
735
+ for last_edge in last_edges:
736
+ pos1 = tuple([int(_) for _ in last_edge[0]["points"].tolist()])
737
+ pos2 = tuple([int(_) for _ in last_edge[1]["points"].tolist()])
738
+ if abs(pos1[0] - pos2[0]) < abs(pos1[1] - pos2[1]):
739
+ if pos1[1] > pos2[1]:
740
+ temp_tensor_0[
741
+ pos2[1] : pos1[1] + 1,
742
+ int((pos1[0] + pos2[0]) / 2) - int(line_width / 2) : int((pos1[0] + pos2[0]) / 2)
743
+ + int(line_width / 2)
744
+ + 1,
745
+ ] = 0
746
+ temp_tensor_1[
747
+ pos2[1] : pos1[1] + 1,
748
+ int((pos1[0] + pos2[0]) / 2) - int(line_width / 2) : int((pos1[0] + pos2[0]) / 2)
749
+ + int(line_width / 2)
750
+ + 1,
751
+ ] = 0
752
+ temp_tensor_2[
753
+ pos2[1] : pos1[1] + 1,
754
+ int((pos1[0] + pos2[0]) / 2) - int(line_width / 2) : int((pos1[0] + pos2[0]) / 2)
755
+ + int(line_width / 2)
756
+ + 1,
757
+ ] = 255
758
+ else:
759
+ temp_tensor_0[
760
+ pos1[1] : pos2[1] + 1,
761
+ int((pos2[0] + pos1[0]) / 2) - int(line_width / 2) : int((pos2[0] + pos1[0]) / 2)
762
+ + int(line_width / 2)
763
+ + 1,
764
+ ] = 0
765
+ temp_tensor_1[
766
+ pos1[1] : pos2[1] + 1,
767
+ int((pos2[0] + pos1[0]) / 2) - int(line_width / 2) : int((pos2[0] + pos1[0]) / 2)
768
+ + int(line_width / 2)
769
+ + 1,
770
+ ] = 0
771
+ temp_tensor_2[
772
+ pos1[1] : pos2[1] + 1,
773
+ int((pos2[0] + pos1[0]) / 2) - int(line_width / 2) : int((pos2[0] + pos1[0]) / 2)
774
+ + int(line_width / 2)
775
+ + 1,
776
+ ] = 255
777
+ else:
778
+ if pos1[0] > pos2[0]:
779
+ temp_tensor_0[
780
+ int((pos1[1] + pos2[1]) / 2) - int(line_width / 2) : int((pos1[1] + pos2[1]) / 2)
781
+ + int(line_width / 2)
782
+ + 1,
783
+ pos2[0] : pos1[0] + 1,
784
+ ] = 0
785
+ temp_tensor_1[
786
+ int((pos1[1] + pos2[1]) / 2) - int(line_width / 2) : int((pos1[1] + pos2[1]) / 2)
787
+ + int(line_width / 2)
788
+ + 1,
789
+ pos2[0] : pos1[0] + 1,
790
+ ] = 0
791
+ temp_tensor_2[
792
+ int((pos1[1] + pos2[1]) / 2) - int(line_width / 2) : int((pos1[1] + pos2[1]) / 2)
793
+ + int(line_width / 2)
794
+ + 1,
795
+ pos2[0] : pos1[0] + 1,
796
+ ] = 255
797
+ else:
798
+ temp_tensor_0[
799
+ int((pos2[1] + pos1[1]) / 2) - int(line_width / 2) : int((pos2[1] + pos1[1]) / 2)
800
+ + int(line_width / 2)
801
+ + 1,
802
+ pos1[0] : pos2[0] + 1,
803
+ ] = 0
804
+ temp_tensor_1[
805
+ int((pos2[1] + pos1[1]) / 2) - int(line_width / 2) : int((pos2[1] + pos1[1]) / 2)
806
+ + int(line_width / 2)
807
+ + 1,
808
+ pos1[0] : pos2[0] + 1,
809
+ ] = 0
810
+ temp_tensor_2[
811
+ int((pos2[1] + pos1[1]) / 2) - int(line_width / 2) : int((pos2[1] + pos1[1]) / 2)
812
+ + int(line_width / 2)
813
+ + 1,
814
+ pos1[0] : pos2[0] + 1,
815
+ ] = 255
816
+ for this_edge in this_edges:
817
+ pos1 = tuple([int(_) for _ in this_edge[0]["points"].tolist()])
818
+ pos2 = tuple([int(_) for _ in this_edge[1]["points"].tolist()])
819
+ if abs(pos1[0] - pos2[0]) < abs(pos1[1] - pos2[1]):
820
+ if pos1[1] > pos2[1]:
821
+ temp_tensor_0[
822
+ pos2[1] : pos1[1] + 1,
823
+ int((pos1[0] + pos2[0]) / 2) - int(line_width / 2) : int((pos1[0] + pos2[0]) / 2)
824
+ + int(line_width / 2)
825
+ + 1,
826
+ ] = 0
827
+ temp_tensor_1[
828
+ pos2[1] : pos1[1] + 1,
829
+ int((pos1[0] + pos2[0]) / 2) - int(line_width / 2) : int((pos1[0] + pos2[0]) / 2)
830
+ + int(line_width / 2)
831
+ + 1,
832
+ ] = 0
833
+ temp_tensor_2[
834
+ pos2[1] : pos1[1] + 1,
835
+ int((pos1[0] + pos2[0]) / 2) - int(line_width / 2) : int((pos1[0] + pos2[0]) / 2)
836
+ + int(line_width / 2)
837
+ + 1,
838
+ ] = 255
839
+ else:
840
+ temp_tensor_0[
841
+ pos1[1] : pos2[1] + 1,
842
+ int((pos2[0] + pos1[0]) / 2) - int(line_width / 2) : int((pos2[0] + pos1[0]) / 2)
843
+ + int(line_width / 2)
844
+ + 1,
845
+ ] = 0
846
+ temp_tensor_1[
847
+ pos1[1] : pos2[1] + 1,
848
+ int((pos2[0] + pos1[0]) / 2) - int(line_width / 2) : int((pos2[0] + pos1[0]) / 2)
849
+ + int(line_width / 2)
850
+ + 1,
851
+ ] = 0
852
+ temp_tensor_2[
853
+ pos1[1] : pos2[1] + 1,
854
+ int((pos2[0] + pos1[0]) / 2) - int(line_width / 2) : int((pos2[0] + pos1[0]) / 2)
855
+ + int(line_width / 2)
856
+ + 1,
857
+ ] = 255
858
+ else:
859
+ if pos1[0] > pos2[0]:
860
+ temp_tensor_0[
861
+ int((pos1[1] + pos2[1]) / 2) - int(line_width / 2) : int((pos1[1] + pos2[1]) / 2)
862
+ + int(line_width / 2)
863
+ + 1,
864
+ pos2[0] : pos1[0] + 1,
865
+ ] = 0
866
+ temp_tensor_1[
867
+ int((pos1[1] + pos2[1]) / 2) - int(line_width / 2) : int((pos1[1] + pos2[1]) / 2)
868
+ + int(line_width / 2)
869
+ + 1,
870
+ pos2[0] : pos1[0] + 1,
871
+ ] = 0
872
+ temp_tensor_2[
873
+ int((pos1[1] + pos2[1]) / 2) - int(line_width / 2) : int((pos1[1] + pos2[1]) / 2)
874
+ + int(line_width / 2)
875
+ + 1,
876
+ pos2[0] : pos1[0] + 1,
877
+ ] = 255
878
+ else:
879
+ temp_tensor_0[
880
+ int((pos2[1] + pos1[1]) / 2) - int(line_width / 2) : int((pos2[1] + pos1[1]) / 2)
881
+ + int(line_width / 2)
882
+ + 1,
883
+ pos1[0] : pos2[0] + 1,
884
+ ] = 0
885
+ temp_tensor_1[
886
+ int((pos2[1] + pos1[1]) / 2) - int(line_width / 2) : int((pos2[1] + pos1[1]) / 2)
887
+ + int(line_width / 2)
888
+ + 1,
889
+ pos1[0] : pos2[0] + 1,
890
+ ] = 0
891
+ temp_tensor_2[
892
+ int((pos2[1] + pos1[1]) / 2) - int(line_width / 2) : int((pos2[1] + pos1[1]) / 2)
893
+ + int(line_width / 2)
894
+ + 1,
895
+ pos1[0] : pos2[0] + 1,
896
+ ] = 255
897
+
898
+ unnormalized = torch.stack((temp_tensor_0, temp_tensor_1, temp_tensor_2), dim=0)
899
+ unnormalized_list.append(unnormalized)
900
+
901
+ temp_tensor_0_renorm = ((temp_tensor_0 / 255) - mean[0]) / std[0]
902
+ temp_tensor_1_renorm = ((temp_tensor_1 / 255) - mean[1]) / std[1]
903
+ temp_tensor_2_renorm = ((temp_tensor_2 / 255) - mean[2]) / std[2]
904
+
905
+ temp_tensor = torch.stack([temp_tensor_0_renorm, temp_tensor_1_renorm, temp_tensor_2_renorm], dim=0)
906
+
907
+ tensors_list.append(temp_tensor)
908
+
909
+ return torch.stack(tensors_list, dim=0), torch.stack(unnormalized_list, dim=0)
910
+
911
+
912
+ def edge_inside(edge, edges_list):
913
+ edge_point1 = tuple(edge[0]["points"].tolist())
914
+ edge_point2 = tuple(edge[1]["points"].tolist())
915
+ for edge_i in edges_list:
916
+ edge_i_point1 = tuple(edge_i[0]["points"].tolist())
917
+ edge_i_point2 = tuple(edge_i[1]["points"].tolist())
918
+ if ((edge_point1 == edge_i_point1) and (edge_point2 == edge_i_point2)) or (
919
+ (edge_point1 == edge_i_point2) and (edge_point2 == edge_i_point1)
920
+ ):
921
+ return True
922
+ return False
923
+
924
+
925
+ def remove_edge(edge, edges_list):
926
+ result = []
927
+ edge_point1 = tuple(edge[0]["points"].tolist())
928
+ edge_point2 = tuple(edge[1]["points"].tolist())
929
+ for edge_i in edges_list:
930
+ edge_i_point1 = tuple(edge_i[0]["points"].tolist())
931
+ edge_i_point2 = tuple(edge_i[1]["points"].tolist())
932
+ if (edge_point1 == edge_i_point1) and (edge_point2 == edge_i_point2):
933
+ pass
934
+ else:
935
+ result.append(edge_i)
936
+ return result
937
+
938
+
939
+ def get_edges_amount(preds):
940
+ count = 0
941
+ for this_preds, last_edges, this_edges in preds:
942
+ count += len(last_edges)
943
+ count += len(this_edges)
944
+ return count
945
+
946
+
947
+ def get_reserve_preds(results, keep_confidence_threshold, targets):
948
+ reserve_preds = []
949
+
950
+ valid_label_indices_edges = torch.where(results["edges"] != 0)[0]
951
+ valid_label_indices_scores = torch.where(results["scores"] <= keep_confidence_threshold)[0]
952
+ valid_label_indices = torch.tensor(
953
+ list(set(valid_label_indices_edges.tolist()).intersection(set(valid_label_indices_scores.tolist()))),
954
+ dtype=valid_label_indices_edges.dtype,
955
+ device=valid_label_indices_edges.device,
956
+ )
957
+ for valid_label_indice in valid_label_indices:
958
+ valid_results_i = {}
959
+ valid_results_i["scores"] = results["scores"][valid_label_indice]
960
+ valid_results_i["points"] = results["points"][valid_label_indice]
961
+ valid_results_i["last_edges"] = results["last_edges"][valid_label_indice]
962
+ valid_results_i["this_edges"] = results["this_edges"][valid_label_indice]
963
+ valid_results_i["edges"] = results["edges"][valid_label_indice]
964
+ valid_results_i["size"] = targets[0]["size"]
965
+ reserve_preds.append(valid_results_i)
966
+ return reserve_preds
data_preprocess/raster2graph/util/edges_utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ edges = {
2
+ 0: "0000",
3
+ 1: "0001",
4
+ 2: "0010",
5
+ 3: "0011",
6
+ 4: "0100",
7
+ 5: "0110",
8
+ 6: "0111",
9
+ 7: "1000",
10
+ 8: "1001",
11
+ 9: "1011",
12
+ 10: "1100",
13
+ 11: "1101",
14
+ 12: "1110",
15
+ 13: "1111",
16
+ 14: "0101",
17
+ 15: "1010",
18
+ }
19
+
20
+
21
+ def get_edges_alldirections(edges_class):
22
+ return edges[edges_class]
23
+
24
+
25
+ edges_rev = {
26
+ "0000": 0,
27
+ "0001": 1,
28
+ "0010": 2,
29
+ "0011": 3,
30
+ "0100": 4,
31
+ "0110": 5,
32
+ "0111": 6,
33
+ "1000": 7,
34
+ "1001": 8,
35
+ "1011": 9,
36
+ "1100": 10,
37
+ "1101": 11,
38
+ "1110": 12,
39
+ "1111": 13,
40
+ "0101": 14,
41
+ "1010": 15,
42
+ }
43
+
44
+
45
+ def get_edges_alldirections_rev(edges_class_rev):
46
+ return edges_rev[edges_class_rev]
data_preprocess/raster2graph/util/geom_utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from shapely.geometry import Polygon
4
+
5
+
6
+ def poly_iou(poly1: Polygon, poly2: Polygon):
7
+ try:
8
+ intersection_area = poly1.intersection(poly2).area
9
+ union_area = poly1.union(poly2).area
10
+ return intersection_area / union_area
11
+ except Exception:
12
+ poly1 = poly1.buffer(1)
13
+ poly2 = poly2.buffer(1)
14
+ intersection_area = poly1.intersection(poly2).area
15
+ union_area = poly1.union(poly2).area
16
+ return intersection_area / union_area
17
+
18
+
19
+ def is_clockwise_or_not(points):
20
+ s = 0
21
+ for i in range(0, len(points) - 1):
22
+ s += points[i][0] * points[i + 1][1] - points[i][1] * points[i + 1][0]
23
+ return s > 0
24
+
25
+
26
+ def x_axis_angle(y):
27
+ # 以图像坐标系为准,(1,0)方向记为0度,逆时针绕一圈到360度
28
+ # print('-------------')
29
+ # print(y)
30
+ y_right_hand = (y[0], -y[1])
31
+ # print(y_right_hand)
32
+
33
+ x = (1, 0)
34
+ inner = x[0] * y_right_hand[0] + x[1] * y_right_hand[1]
35
+ # print(inner)
36
+ y_norm2 = (y_right_hand[0] ** 2 + y_right_hand[1] ** 2) ** 0.5
37
+ # print(y_norm2)
38
+ cosxy = inner / (y_norm2 + 1e-8)
39
+ # print(cosxy)
40
+ angle = math.acos(cosxy)
41
+ # print(angle, math.degrees(angle))
42
+ # print('-------------')
43
+ return math.degrees(angle) if y_right_hand[1] >= 0 else 360 - math.degrees(angle)
44
+
45
+
46
+ def get_quadrant(angle):
47
+ if angle[0] < angle[1]:
48
+ if 0 <= angle[0] < 90 and 0 <= angle[1] < 90:
49
+ quadrant = (angle[1] - angle[0], 0, 0, 0)
50
+ elif 0 <= angle[0] < 90 and 90 <= angle[1] < 180:
51
+ quadrant = (90 - angle[0], angle[1] - 90, 0, 0)
52
+ elif 0 <= angle[0] < 90 and 180 <= angle[1] < 270:
53
+ quadrant = (90 - angle[0], 90, angle[1] - 180, 0)
54
+ elif 0 <= angle[0] < 90 and 270 <= angle[1] < 360:
55
+ quadrant = (90 - angle[0], 90, 90, angle[1] - 270)
56
+ elif 90 <= angle[0] < 180 and 90 <= angle[1] < 180:
57
+ quadrant = (0, angle[1] - angle[0], 0, 0)
58
+ elif 90 <= angle[0] < 180 and 180 <= angle[1] < 270:
59
+ quadrant = (0, 180 - angle[0], angle[1] - 180, 0)
60
+ elif 90 <= angle[0] < 180 and 270 <= angle[1] < 360:
61
+ quadrant = (0, 180 - angle[0], 90, angle[1] - 270)
62
+ elif 180 <= angle[0] < 270 and 180 <= angle[1] < 270:
63
+ quadrant = (0, 0, angle[1] - angle[0], 0)
64
+ elif 180 <= angle[0] < 270 and 270 <= angle[1] < 360:
65
+ quadrant = (0, 0, 270 - angle[0], angle[1] - 270)
66
+ elif 270 <= angle[0] < 360 and 270 <= angle[1] < 360:
67
+ quadrant = (0, 0, 0, angle[1] - angle[0])
68
+ else:
69
+ if 0 <= angle[1] < 90 and 0 <= angle[0] < 90:
70
+ quadrant_ = (angle[0] - angle[1], 0, 0, 0)
71
+ elif 0 <= angle[1] < 90 and 90 <= angle[0] < 180:
72
+ quadrant_ = (90 - angle[1], angle[0] - 90, 0, 0)
73
+ elif 0 <= angle[1] < 90 and 180 <= angle[0] < 270:
74
+ quadrant_ = (90 - angle[1], 90, angle[0] - 180, 0)
75
+ elif 0 <= angle[1] < 90 and 270 <= angle[0] < 360:
76
+ quadrant_ = (90 - angle[1], 90, 90, angle[0] - 270)
77
+ elif 90 <= angle[1] < 180 and 90 <= angle[0] < 180:
78
+ quadrant_ = (0, angle[0] - angle[1], 0, 0)
79
+ elif 90 <= angle[1] < 180 and 180 <= angle[0] < 270:
80
+ quadrant_ = (0, 180 - angle[1], angle[0] - 180, 0)
81
+ elif 90 <= angle[1] < 180 and 270 <= angle[0] < 360:
82
+ quadrant_ = (0, 180 - angle[1], 90, angle[0] - 270)
83
+ elif 180 <= angle[1] < 270 and 180 <= angle[0] < 270:
84
+ quadrant_ = (0, 0, angle[0] - angle[1], 0)
85
+ elif 180 <= angle[1] < 270 and 270 <= angle[0] < 360:
86
+ quadrant_ = (0, 0, 270 - angle[1], angle[0] - 270)
87
+ elif 270 <= angle[1] < 360 and 270 <= angle[0] < 360:
88
+ quadrant_ = (0, 0, 0, angle[0] - angle[1])
89
+ quadrant = (90 - quadrant_[0], 90 - quadrant_[1], 90 - quadrant_[2], 90 - quadrant_[3])
90
+ return quadrant
91
+
92
+
93
+ def find_which_angle_to_counterclockwise_rotate_from(t):
94
+ if t > 270:
95
+ return 630 - t
96
+ else:
97
+ return 270 - t
98
+
99
+
100
+ def counter_degree(d):
101
+ if d >= 180:
102
+ return d - 180
103
+ else:
104
+ return d + 180
105
+
106
+
107
+ def rotate_degree_clockwise_from_counter_degree(src_degree, dest_degree):
108
+ delta = src_degree - dest_degree
109
+ return delta if delta >= 0 else 360 + delta
110
+
111
+
112
+ def rotate_degree_counterclockwise_from_counter_degree(src_degree, dest_degree):
113
+ delta = dest_degree - src_degree
114
+ return delta if delta >= 0 else 360 + delta
115
+
116
+
117
+ def poly_area(points):
118
+ s = 0
119
+ points_count = len(points)
120
+ for i in range(points_count):
121
+ point = points[i]
122
+ point2 = points[(i + 1) % points_count]
123
+ s += (point[0] - point2[0]) * (point[1] + point2[1])
124
+ return s / 2
data_preprocess/raster2graph/util/graph_utils.py ADDED
@@ -0,0 +1,879 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import random
3
+
4
+ import networkx as nx
5
+ import numpy as np
6
+ import torch
7
+ from util.geom_utils import (
8
+ get_quadrant,
9
+ is_clockwise_or_not,
10
+ poly_area,
11
+ rotate_degree_counterclockwise_from_counter_degree,
12
+ x_axis_angle,
13
+ )
14
+ from util.metric_utils import get_results, get_results_float_with_semantic
15
+
16
+
17
+ def graph_to_tensor(graph):
18
+ t_l = []
19
+ for k, v in graph.items():
20
+ a = []
21
+ a.append(k)
22
+ a.extend(v)
23
+ b = [list(i) for i in a]
24
+ c = torch.tensor(b)
25
+ t_l.append(c)
26
+ return torch.stack(t_l, dim=0)
27
+
28
+
29
+ def tensor_to_graph(tensor):
30
+ gr = {}
31
+ for kv in tensor:
32
+ k = tuple([i.item() for i in kv[0]])
33
+ v = kv[1:5]
34
+ v = v.tolist()
35
+ v = [tuple(i) for i in v]
36
+ gr[k] = v
37
+ return gr
38
+
39
+
40
+ def tensors_to_graphs_batch(tensors):
41
+ return [tensor_to_graph(ts) for ts in tensors]
42
+
43
+
44
+ def get_cycle_basis_and_semantic_deprecated(best_result):
45
+ output_points, output_edges = get_results_float_with_semantic(best_result)
46
+ d = {}
47
+ for output_point_index, output_point in enumerate(output_points):
48
+ d[output_point] = output_point_index
49
+ d_rev = {}
50
+ for output_point_index, output_point in enumerate(output_points):
51
+ d_rev[output_point_index] = output_point
52
+ es = []
53
+ for output_edge in output_edges:
54
+ es.append((d[output_edge[0]], d[output_edge[1]]))
55
+
56
+ G = nx.Graph()
57
+ for e in es:
58
+ G.add_edge(e[0], e[1])
59
+
60
+ nx.draw(G)
61
+ # plt.show()
62
+ simple_cycles = nx.cycle_basis(G)
63
+
64
+ results = []
65
+
66
+ for cycle_ind, cycle in enumerate(simple_cycles):
67
+ points = [d_rev[ind] for ind in cycle]
68
+ points.append(points[0])
69
+
70
+ is_clockwise = is_clockwise_or_not([(p[0], p[1]) for p in points])
71
+ if is_clockwise:
72
+ points.reverse()
73
+
74
+ cross_products = []
75
+ poses = [(p[0], p[1]) for p in points]
76
+ for ind in range(len(poses) - 1):
77
+ ei = [
78
+ poses[(ind + 1) % (len(poses) - 1)][0] - poses[ind][0],
79
+ poses[(ind + 1) % (len(poses) - 1)][1] - poses[ind][1],
80
+ ]
81
+ eiplus1 = [
82
+ poses[(ind + 2) % (len(poses) - 1)][0] - poses[(ind + 1) % (len(poses) - 1)][0],
83
+ poses[(ind + 2) % (len(poses) - 1)][1] - poses[(ind + 1) % (len(poses) - 1)][1],
84
+ ]
85
+ cross_products.append(np.cross(ei, eiplus1).tolist())
86
+ cross_products.insert(0, cross_products[-1])
87
+ cross_products.pop(-1)
88
+
89
+ while 0 in cross_products:
90
+ for point_ind, cross_product in enumerate(cross_products):
91
+ if cross_product == 0:
92
+ if point_ind == 0:
93
+ p0 = copy.deepcopy(points[0])
94
+ points[0] = (
95
+ p0[0] + 0.000001 * random.random() * [-1, 1][random.randint(0, 1)],
96
+ p0[1] + 0.000001 * random.random() * [-1, 1][random.randint(0, 1)],
97
+ p0[2],
98
+ p0[3],
99
+ p0[4],
100
+ p0[5],
101
+ )
102
+ points[-1] = copy.deepcopy(points[0])
103
+ else:
104
+ pi = copy.deepcopy(points[point_ind])
105
+ points[point_ind] = (
106
+ pi[0] + 0.000001 * random.random() * [-1, 1][random.randint(0, 1)],
107
+ pi[1] + 0.000001 * random.random() * [-1, 1][random.randint(0, 1)],
108
+ pi[2],
109
+ pi[3],
110
+ pi[4],
111
+ pi[5],
112
+ )
113
+ # print(points)
114
+ cross_products = []
115
+ poses = [(p[0], p[1]) for p in points]
116
+ for ind in range(len(poses) - 1):
117
+ ei = [
118
+ poses[(ind + 1) % (len(poses) - 1)][0] - poses[ind][0],
119
+ poses[(ind + 1) % (len(poses) - 1)][1] - poses[ind][1],
120
+ ]
121
+ eiplus1 = [
122
+ poses[(ind + 2) % (len(poses) - 1)][0] - poses[(ind + 1) % (len(poses) - 1)][0],
123
+ poses[(ind + 2) % (len(poses) - 1)][1] - poses[(ind + 1) % (len(poses) - 1)][1],
124
+ ]
125
+ cross_products.append(np.cross(ei, eiplus1).tolist())
126
+ cross_products.insert(0, cross_products[-1])
127
+ cross_products.pop(-1)
128
+
129
+ semantics = [[p[2], p[3], p[4], p[5]] for p in points]
130
+
131
+ degrees = []
132
+ for ind in range(len(poses) - 1):
133
+ ei_minus = [
134
+ -(poses[(ind + 1) % (len(poses) - 1)][0] - poses[ind][0]),
135
+ -(poses[(ind + 1) % (len(poses) - 1)][1] - poses[ind][1]),
136
+ ]
137
+
138
+ eiplus1 = [
139
+ poses[(ind + 2) % (len(poses) - 1)][0] - poses[(ind + 1) % (len(poses) - 1)][0],
140
+ poses[(ind + 2) % (len(poses) - 1)][1] - poses[(ind + 1) % (len(poses) - 1)][1],
141
+ ]
142
+
143
+ degrees.append((x_axis_angle(ei_minus), x_axis_angle(eiplus1)))
144
+ degrees.insert(0, degrees[-1])
145
+ degrees.pop(-1)
146
+
147
+ angles = []
148
+ for degree in degrees:
149
+ angles.append(((min(degree), max(degree)), (max(degree), min(degree))))
150
+
151
+ angles_to_semantics = []
152
+ for angle_ind, angle in enumerate(angles):
153
+ angle1 = angle[0]
154
+ angle2 = angle[1]
155
+ quadrant1 = get_quadrant(angle1)
156
+ quadrant2 = get_quadrant(angle2)
157
+
158
+ semantic1 = (
159
+ semantics[angle_ind][1] if quadrant1[0] >= 45 else -1,
160
+ semantics[angle_ind][0] if quadrant1[1] >= 45 else -1,
161
+ semantics[angle_ind][3] if quadrant1[2] >= 45 else -1,
162
+ semantics[angle_ind][2] if quadrant1[3] >= 45 else -1,
163
+ )
164
+ semantic2 = (
165
+ semantics[angle_ind][1] if quadrant2[0] >= 45 else -1,
166
+ semantics[angle_ind][0] if quadrant2[1] >= 45 else -1,
167
+ semantics[angle_ind][3] if quadrant2[2] >= 45 else -1,
168
+ semantics[angle_ind][2] if quadrant2[3] >= 45 else -1,
169
+ )
170
+
171
+ angle1_degree = sum(quadrant1)
172
+ angle2_degree = sum(quadrant2)
173
+
174
+ xproduct = cross_products[angle_ind]
175
+
176
+ if xproduct < 0:
177
+ if angle1_degree < angle2_degree:
178
+ angles_to_semantics.append(semantic1)
179
+ else:
180
+ angles_to_semantics.append(semantic2)
181
+ elif xproduct > 0:
182
+ if angle1_degree < angle2_degree:
183
+ angles_to_semantics.append(semantic2)
184
+ else:
185
+ angles_to_semantics.append(semantic1)
186
+ else:
187
+ assert 0
188
+
189
+ semantic_result = {}
190
+ for semantic_label in range(0, 13):
191
+ semantic_result[semantic_label] = 0
192
+ for everypoint_semantic in angles_to_semantics:
193
+ everypoint_semantic = [s for s in everypoint_semantic if s != -1]
194
+ for label in everypoint_semantic:
195
+ semantic_result[label] += 1 / len(everypoint_semantic)
196
+
197
+ this_cycle_semantic1 = sorted(semantic_result.items(), key=lambda d: d[1], reverse=True)
198
+ this_cycle_result = None
199
+ if this_cycle_semantic1[0][1] > this_cycle_semantic1[1][1]:
200
+ this_cycle_result = this_cycle_semantic1[0][0]
201
+ else:
202
+ this_cycle_results = [i[0] for i in this_cycle_semantic1 if i[1] == this_cycle_semantic1[0][1]]
203
+ this_cycle_result = this_cycle_results[random.randint(0, len(this_cycle_results) - 1)]
204
+ results.append(this_cycle_result)
205
+
206
+ return d_rev, simple_cycles, results
207
+
208
+
209
+ def get_cycle_basis_and_semantic(best_result):
210
+ output_points, output_edges = get_results_float_with_semantic(best_result)
211
+ output_points = copy.deepcopy(output_points)
212
+ output_edges = copy.deepcopy(output_edges)
213
+
214
+ d = {}
215
+ for output_point_index, output_point in enumerate(output_points):
216
+ d[output_point] = output_point_index
217
+ d_rev = {}
218
+ for output_point_index, output_point in enumerate(output_points):
219
+ d_rev[output_point_index] = output_point
220
+ es = []
221
+ for output_edge in output_edges:
222
+ es.append((d[output_edge[0]], d[output_edge[1]]))
223
+ # print(d)
224
+
225
+ G = nx.Graph()
226
+ for e in es:
227
+ G.add_edge(e[0], e[1])
228
+
229
+ simple_cycles = []
230
+ simple_cycles_number = []
231
+ simple_cycles_semantics = []
232
+ bridges = list(nx.bridges(G))
233
+ for b in bridges:
234
+ if (d_rev[b[0]], d_rev[b[1]]) in output_edges:
235
+ output_edges.remove((d_rev[b[0]], d_rev[b[1]]))
236
+ es.remove((b[0], b[1]))
237
+ G.remove_edge(b[0], b[1])
238
+ if (d_rev[b[1]], d_rev[b[0]]) in output_edges:
239
+ output_edges.remove((d_rev[b[1]], d_rev[b[0]]))
240
+ es.remove((b[1], b[0]))
241
+ G.remove_edge(b[1], b[0])
242
+ connected_components = list(nx.connected_components(G))
243
+ for c in connected_components:
244
+ if len(c) == 1:
245
+ pass
246
+ else:
247
+ simple_cycles_c = []
248
+ simple_cycles_number_c = []
249
+ simple_cycle_semantics_c = []
250
+ # output_points_c = [p for p in output_points if d[p] in c]
251
+ output_edges_c = [e for e in output_edges if d[e[0]] in c or d[e[1]] in c]
252
+ output_edges_c_copy_for_traversing = copy.deepcopy(output_edges_c)
253
+
254
+ for edge_c in output_edges_c:
255
+ if edge_c not in output_edges_c_copy_for_traversing:
256
+ pass
257
+ else:
258
+ simple_cycle_semantics = []
259
+ simple_cycle = []
260
+ simple_cycle_number = []
261
+ point1 = edge_c[0]
262
+ point2 = edge_c[1]
263
+ point1_number = d[point1]
264
+ point2_number = d[point2]
265
+
266
+ initial_point = None
267
+ initial_point_number = None
268
+ if point1_number < point2_number:
269
+ initial_point = point1
270
+ initial_point_number = point1_number
271
+ else:
272
+ initial_point = point2
273
+ initial_point_number = point2_number
274
+ simple_cycle.append(initial_point)
275
+ simple_cycle_number.append(initial_point_number)
276
+
277
+ last_point = initial_point
278
+
279
+ current_point = None
280
+ current_point_number = None
281
+ if point1_number < point2_number:
282
+ current_point = point2
283
+ current_point_number = point2_number
284
+ else:
285
+ current_point = point1
286
+ current_point_number = point1_number
287
+ simple_cycle.append(current_point)
288
+ simple_cycle_number.append(current_point_number)
289
+
290
+ next_initial_point = copy.deepcopy(current_point)
291
+
292
+ next_point = None
293
+ next_point_number = None
294
+
295
+ while next_point != next_initial_point:
296
+ relevant_edges = []
297
+ for edge in output_edges_c:
298
+ if edge[0] == current_point or edge[1] == current_point:
299
+ relevant_edges.append(edge)
300
+
301
+ relevant_edges_degree = []
302
+ for relevant_edge in relevant_edges:
303
+ vec = None
304
+ if relevant_edge[0] == current_point:
305
+ vec = (
306
+ relevant_edge[1][0] - relevant_edge[0][0],
307
+ relevant_edge[1][1] - relevant_edge[0][1],
308
+ )
309
+ elif relevant_edge[1] == current_point:
310
+ vec = (
311
+ relevant_edge[0][0] - relevant_edge[1][0],
312
+ relevant_edge[0][1] - relevant_edge[1][1],
313
+ )
314
+ else:
315
+ assert 0
316
+
317
+ vec_degree = x_axis_angle(vec)
318
+ relevant_edges_degree.append(vec_degree)
319
+
320
+ vec_from_current_point_to_last_point_degree = None
321
+ for relevant_edge_ind, relevant_edge in enumerate(relevant_edges):
322
+ if relevant_edge == (current_point, last_point):
323
+ vec_from_current_point_to_last_point_degree = relevant_edges_degree[relevant_edge_ind]
324
+ relevant_edges.remove(relevant_edge)
325
+ relevant_edges_degree.remove(vec_from_current_point_to_last_point_degree)
326
+ elif relevant_edge == (last_point, current_point):
327
+ vec_from_current_point_to_last_point_degree = relevant_edges_degree[relevant_edge_ind]
328
+ relevant_edges.remove(relevant_edge)
329
+ relevant_edges_degree.remove(vec_from_current_point_to_last_point_degree)
330
+ else:
331
+ continue
332
+
333
+ rotate_deltas_counterclockwise = []
334
+
335
+ interior_angles = []
336
+ for relevant_edge_degree in relevant_edges_degree:
337
+ rotate_delta = rotate_degree_counterclockwise_from_counter_degree(
338
+ vec_from_current_point_to_last_point_degree, relevant_edge_degree
339
+ )
340
+ rotate_deltas_counterclockwise.append(rotate_delta)
341
+ interior_angles.append((relevant_edge_degree, vec_from_current_point_to_last_point_degree))
342
+ # print(rotate_deltas_counterclockwise)
343
+
344
+ max_rotate_index = rotate_deltas_counterclockwise.index(max(rotate_deltas_counterclockwise))
345
+
346
+ interior_angle_counterclockwise = interior_angles[max_rotate_index]
347
+
348
+ current_point_semantic = [
349
+ current_point[3],
350
+ current_point[2],
351
+ current_point[5],
352
+ current_point[4],
353
+ ]
354
+
355
+ interior_angle_counterclockwise_degree_smaller = min(interior_angle_counterclockwise)
356
+ interior_angle_counterclockwise_degree_bigger = max(interior_angle_counterclockwise)
357
+ quadrant_smaller_to_bigger_counterclockwise = get_quadrant(
358
+ (
359
+ interior_angle_counterclockwise_degree_smaller,
360
+ interior_angle_counterclockwise_degree_bigger,
361
+ )
362
+ )
363
+ # print(quadrant_smaller_to_bigger_counterclockwise)
364
+ if interior_angle_counterclockwise.index(interior_angle_counterclockwise_degree_smaller) == 0:
365
+ pass
366
+ elif (
367
+ interior_angle_counterclockwise.index(interior_angle_counterclockwise_degree_smaller) == 1
368
+ ):
369
+ quadrant_smaller_to_bigger_counterclockwise = (
370
+ 90 - quadrant_smaller_to_bigger_counterclockwise[0],
371
+ 90 - quadrant_smaller_to_bigger_counterclockwise[1],
372
+ 90 - quadrant_smaller_to_bigger_counterclockwise[2],
373
+ 90 - quadrant_smaller_to_bigger_counterclockwise[3],
374
+ )
375
+ else:
376
+ assert 0
377
+
378
+ current_point_semantic_valid = []
379
+ for qd, seman in enumerate(current_point_semantic):
380
+ if quadrant_smaller_to_bigger_counterclockwise[qd] >= 45:
381
+ current_point_semantic_valid.append(seman)
382
+ else:
383
+ current_point_semantic_valid.append(-1)
384
+
385
+ simple_cycle_semantics.append(current_point_semantic_valid)
386
+
387
+ max_rotate_edge = relevant_edges[max_rotate_index]
388
+
389
+ if max_rotate_edge[0] == current_point:
390
+ next_point = max_rotate_edge[1]
391
+ next_point_number = d[next_point]
392
+ elif max_rotate_edge[1] == current_point:
393
+ next_point = max_rotate_edge[0]
394
+ next_point_number = d[next_point]
395
+ else:
396
+ assert 0
397
+
398
+ last_point = current_point
399
+ current_point = next_point
400
+ current_point_number = next_point_number
401
+ simple_cycle.append(current_point)
402
+ simple_cycle_number.append(current_point_number)
403
+
404
+ for point_number_ind, point_number in enumerate(simple_cycle_number):
405
+ if point_number_ind < len(simple_cycle_number) - 1:
406
+ edge_number = (point_number, simple_cycle_number[point_number_ind + 1])
407
+ # print(simple_cycle_number)
408
+ if edge_number[0] < edge_number[1]:
409
+ if (
410
+ d_rev[edge_number[0]],
411
+ d_rev[edge_number[1]],
412
+ ) in output_edges_c_copy_for_traversing:
413
+ output_edges_c_copy_for_traversing.remove(
414
+ (d_rev[edge_number[0]], d_rev[edge_number[1]])
415
+ )
416
+ elif (
417
+ d_rev[edge_number[1]],
418
+ d_rev[edge_number[0]],
419
+ ) in output_edges_c_copy_for_traversing:
420
+ output_edges_c_copy_for_traversing.remove(
421
+ (d_rev[edge_number[1]], d_rev[edge_number[0]])
422
+ )
423
+
424
+ simple_cycle.pop(-1)
425
+ simple_cycle_number.pop(-1)
426
+
427
+ polygon_counterclockwise = [(int(p[0]), -int(p[1])) for p in simple_cycle]
428
+ polygon_counterclockwise.pop(-1)
429
+ # print('poly_area(polygon_counterclockwise)', poly_area(polygon_counterclockwise))
430
+ if poly_area(polygon_counterclockwise) > 0:
431
+ simple_cycles_c.append(simple_cycle)
432
+ simple_cycles_number_c.append(simple_cycle_number)
433
+
434
+ semantic_result = {}
435
+ for semantic_label in range(0, 13):
436
+ semantic_result[semantic_label] = 0
437
+ for everypoint_semantic in simple_cycle_semantics:
438
+ everypoint_semantic = [s for s in everypoint_semantic if s != -1]
439
+ for label in everypoint_semantic:
440
+ semantic_result[label] += 1 / len(everypoint_semantic)
441
+ # print(semantic_result)
442
+ del semantic_result[11]
443
+ del semantic_result[12]
444
+
445
+ this_cycle_semantic = sorted(semantic_result.items(), key=lambda d: d[1], reverse=True)
446
+ # print(this_cycle_semantic)
447
+ this_cycle_result = None
448
+ if this_cycle_semantic[0][1] > this_cycle_semantic[1][1]:
449
+ this_cycle_result = this_cycle_semantic[0][0]
450
+ else:
451
+ this_cycle_results = [
452
+ i[0] for i in this_cycle_semantic if i[1] == this_cycle_semantic[0][1]
453
+ ]
454
+ this_cycle_result = this_cycle_results[random.randint(0, len(this_cycle_results) - 1)]
455
+ # print(this_cycle_result)
456
+ simple_cycle_semantics_c.append(this_cycle_result)
457
+
458
+ simple_cycles.extend(simple_cycles_c)
459
+ simple_cycles_number.extend(simple_cycles_number_c)
460
+ simple_cycles_semantics.extend(simple_cycle_semantics_c)
461
+
462
+ # print([[(int(j[0]), int(j[1])) for j in i] for i in simple_cycles])
463
+
464
+ # print(len(simple_cycles_number))
465
+ # print(simple_cycles_semantics)
466
+
467
+ return d_rev, simple_cycles, simple_cycles_semantics
468
+
469
+
470
+ def get_cycle_basis_and_semantic_2(best_result):
471
+ output_points, output_edges = get_results_float_with_semantic(best_result)
472
+ output_points = copy.deepcopy(output_points)
473
+ output_edges = copy.deepcopy(output_edges)
474
+ # print(output_points)
475
+ # print(output_edges)
476
+ # assert 0
477
+ d = {}
478
+ for output_point_index, output_point in enumerate(output_points):
479
+ d[output_point] = output_point_index
480
+ d_rev = {}
481
+ for output_point_index, output_point in enumerate(output_points):
482
+ d_rev[output_point_index] = output_point
483
+ es = []
484
+ for output_edge in output_edges:
485
+ es.append((d[output_edge[0]], d[output_edge[1]]))
486
+ # print(d)
487
+
488
+ G = nx.Graph()
489
+ for e in es:
490
+ G.add_edge(e[0], e[1])
491
+
492
+ simple_cycles = []
493
+ simple_cycles_number = []
494
+ simple_cycles_semantics = []
495
+
496
+ bridges = list(nx.bridges(G))
497
+
498
+ for b in bridges:
499
+ if (d_rev[b[0]], d_rev[b[1]]) in output_edges:
500
+ output_edges.remove((d_rev[b[0]], d_rev[b[1]]))
501
+ es.remove((b[0], b[1]))
502
+ G.remove_edge(b[0], b[1])
503
+ if (d_rev[b[1]], d_rev[b[0]]) in output_edges:
504
+ output_edges.remove((d_rev[b[1]], d_rev[b[0]]))
505
+ es.remove((b[1], b[0]))
506
+ G.remove_edge(b[1], b[0])
507
+
508
+ connected_components = list(nx.connected_components(G))
509
+ for c in connected_components:
510
+ if len(c) == 1:
511
+ pass
512
+ else:
513
+ simple_cycles_c = []
514
+ simple_cycles_number_c = []
515
+ simple_cycle_semantics_c = []
516
+ output_edges_c = [e for e in output_edges if d[e[0]] in c or d[e[1]] in c]
517
+ output_edges_c_copy_for_traversing = copy.deepcopy(output_edges_c)
518
+
519
+ for edge_c in output_edges_c:
520
+ if edge_c not in output_edges_c_copy_for_traversing:
521
+ pass
522
+ else:
523
+ simple_cycle_semantics = []
524
+ simple_cycle = []
525
+ simple_cycle_number = []
526
+ point1 = edge_c[0]
527
+ point2 = edge_c[1]
528
+ point1_number = d[point1]
529
+ point2_number = d[point2]
530
+
531
+ initial_point = None
532
+ initial_point_number = None
533
+ if point1_number < point2_number:
534
+ initial_point = point1
535
+ initial_point_number = point1_number
536
+ else:
537
+ initial_point = point2
538
+ initial_point_number = point2_number
539
+ simple_cycle.append(initial_point)
540
+ simple_cycle_number.append(initial_point_number)
541
+
542
+ last_point = initial_point
543
+
544
+ current_point = None
545
+ current_point_number = None
546
+ if point1_number < point2_number:
547
+ current_point = point2
548
+ current_point_number = point2_number
549
+ else:
550
+ current_point = point1
551
+ current_point_number = point1_number
552
+ simple_cycle.append(current_point)
553
+ simple_cycle_number.append(current_point_number)
554
+
555
+ next_initial_point = copy.deepcopy(current_point)
556
+
557
+ next_point = None
558
+ next_point_number = None
559
+
560
+ while next_point != next_initial_point:
561
+ relevant_edges = []
562
+ for edge in output_edges_c:
563
+ if edge[0] == current_point or edge[1] == current_point:
564
+ relevant_edges.append(edge)
565
+
566
+ relevant_edges_degree = []
567
+ for relevant_edge in relevant_edges:
568
+ vec = None
569
+ if relevant_edge[0] == current_point:
570
+ vec = (
571
+ relevant_edge[1][0] - relevant_edge[0][0],
572
+ relevant_edge[1][1] - relevant_edge[0][1],
573
+ )
574
+ elif relevant_edge[1] == current_point:
575
+ vec = (
576
+ relevant_edge[0][0] - relevant_edge[1][0],
577
+ relevant_edge[0][1] - relevant_edge[1][1],
578
+ )
579
+ else:
580
+ assert 0
581
+
582
+ vec_degree = x_axis_angle(vec)
583
+ relevant_edges_degree.append(vec_degree)
584
+
585
+ vec_from_current_point_to_last_point_degree = None
586
+ for relevant_edge_ind, relevant_edge in enumerate(relevant_edges):
587
+ if relevant_edge == (current_point, last_point):
588
+ vec_from_current_point_to_last_point_degree = relevant_edges_degree[relevant_edge_ind]
589
+ relevant_edges.remove(relevant_edge)
590
+ relevant_edges_degree.remove(vec_from_current_point_to_last_point_degree)
591
+ elif relevant_edge == (last_point, current_point):
592
+ vec_from_current_point_to_last_point_degree = relevant_edges_degree[relevant_edge_ind]
593
+ relevant_edges.remove(relevant_edge)
594
+ relevant_edges_degree.remove(vec_from_current_point_to_last_point_degree)
595
+ else:
596
+ continue
597
+
598
+ rotate_deltas_counterclockwise = []
599
+ interior_angles = []
600
+ for relevant_edge_degree in relevant_edges_degree:
601
+ rotate_delta = rotate_degree_counterclockwise_from_counter_degree(
602
+ vec_from_current_point_to_last_point_degree, relevant_edge_degree
603
+ )
604
+ rotate_deltas_counterclockwise.append(rotate_delta)
605
+ interior_angles.append((relevant_edge_degree, vec_from_current_point_to_last_point_degree))
606
+ # print(rotate_deltas_counterclockwise)
607
+ max_rotate_index = rotate_deltas_counterclockwise.index(max(rotate_deltas_counterclockwise))
608
+ interior_angle_counterclockwise = interior_angles[max_rotate_index]
609
+ current_point_semantic = [
610
+ current_point[3],
611
+ current_point[2],
612
+ current_point[5],
613
+ current_point[4],
614
+ ]
615
+ interior_angle_counterclockwise_degree_smaller = min(interior_angle_counterclockwise)
616
+ interior_angle_counterclockwise_degree_bigger = max(interior_angle_counterclockwise)
617
+ quadrant_smaller_to_bigger_counterclockwise = get_quadrant(
618
+ (
619
+ interior_angle_counterclockwise_degree_smaller,
620
+ interior_angle_counterclockwise_degree_bigger,
621
+ )
622
+ )
623
+ if interior_angle_counterclockwise.index(interior_angle_counterclockwise_degree_smaller) == 0:
624
+ pass
625
+ elif (
626
+ interior_angle_counterclockwise.index(interior_angle_counterclockwise_degree_smaller) == 1
627
+ ):
628
+ quadrant_smaller_to_bigger_counterclockwise = (
629
+ 90 - quadrant_smaller_to_bigger_counterclockwise[0],
630
+ 90 - quadrant_smaller_to_bigger_counterclockwise[1],
631
+ 90 - quadrant_smaller_to_bigger_counterclockwise[2],
632
+ 90 - quadrant_smaller_to_bigger_counterclockwise[3],
633
+ )
634
+ else:
635
+ assert 0
636
+ current_point_semantic_valid = []
637
+ for qd, seman in enumerate(current_point_semantic):
638
+ if 1:
639
+ current_point_semantic_valid.append(seman)
640
+ else:
641
+ current_point_semantic_valid.append(-1)
642
+ simple_cycle_semantics.append(current_point_semantic_valid)
643
+
644
+ max_rotate_edge = relevant_edges[max_rotate_index]
645
+ if max_rotate_edge[0] == current_point:
646
+ next_point = max_rotate_edge[1]
647
+ next_point_number = d[next_point]
648
+ elif max_rotate_edge[1] == current_point:
649
+ next_point = max_rotate_edge[0]
650
+ next_point_number = d[next_point]
651
+ else:
652
+ assert 0
653
+
654
+ last_point = current_point
655
+ current_point = next_point
656
+ current_point_number = next_point_number
657
+ simple_cycle.append(current_point)
658
+ simple_cycle_number.append(current_point_number)
659
+
660
+ for point_number_ind, point_number in enumerate(simple_cycle_number):
661
+ if point_number_ind < len(simple_cycle_number) - 1:
662
+ edge_number = (point_number, simple_cycle_number[point_number_ind + 1])
663
+ if edge_number[0] < edge_number[1]:
664
+ if (
665
+ d_rev[edge_number[0]],
666
+ d_rev[edge_number[1]],
667
+ ) in output_edges_c_copy_for_traversing:
668
+ output_edges_c_copy_for_traversing.remove(
669
+ (d_rev[edge_number[0]], d_rev[edge_number[1]])
670
+ )
671
+ elif (
672
+ d_rev[edge_number[1]],
673
+ d_rev[edge_number[0]],
674
+ ) in output_edges_c_copy_for_traversing:
675
+ output_edges_c_copy_for_traversing.remove(
676
+ (d_rev[edge_number[1]], d_rev[edge_number[0]])
677
+ )
678
+
679
+ simple_cycle.pop(-1)
680
+ simple_cycle_number.pop(-1)
681
+ polygon_counterclockwise = [(int(p[0]), -int(p[1])) for p in simple_cycle]
682
+ polygon_counterclockwise.pop(-1)
683
+ if poly_area(polygon_counterclockwise) > 0:
684
+ simple_cycles_c.append(simple_cycle)
685
+ simple_cycles_number_c.append(simple_cycle_number)
686
+ semantic_result = {}
687
+ for semantic_label in range(0, 13):
688
+ semantic_result[semantic_label] = 0
689
+ for everypoint_semantic in simple_cycle_semantics:
690
+ for _ in range(0, 13):
691
+ if _ in everypoint_semantic:
692
+ semantic_result[_] += 1
693
+ del semantic_result[11]
694
+ del semantic_result[12]
695
+
696
+ this_cycle_semantic = sorted(semantic_result.items(), key=lambda d: d[1], reverse=True)
697
+ this_cycle_result = None
698
+ if this_cycle_semantic[0][1] > this_cycle_semantic[1][1]:
699
+ this_cycle_result = this_cycle_semantic[0][0]
700
+ else:
701
+ this_cycle_results = [
702
+ i[0] for i in this_cycle_semantic if i[1] == this_cycle_semantic[0][1]
703
+ ]
704
+ this_cycle_result = this_cycle_results[random.randint(0, len(this_cycle_results) - 1)]
705
+ simple_cycle_semantics_c.append(this_cycle_result)
706
+
707
+ simple_cycles.extend(simple_cycles_c)
708
+ simple_cycles_number.extend(simple_cycles_number_c)
709
+ simple_cycles_semantics.extend(simple_cycle_semantics_c)
710
+
711
+ return d_rev, simple_cycles, simple_cycles_semantics
712
+
713
+
714
+ def get_cycle_basis(best_result):
715
+ output_points, output_edges = get_results(best_result)
716
+ output_points = copy.deepcopy(output_points)
717
+ output_edges = copy.deepcopy(output_edges)
718
+
719
+ d = {}
720
+ for output_point_index, output_point in enumerate(output_points):
721
+ d[output_point] = output_point_index
722
+ d_rev = {}
723
+ for output_point_index, output_point in enumerate(output_points):
724
+ d_rev[output_point_index] = output_point
725
+ es = []
726
+ for output_edge in output_edges:
727
+ es.append((d[output_edge[0]], d[output_edge[1]]))
728
+
729
+ G = nx.Graph()
730
+ for e in es:
731
+ G.add_edge(e[0], e[1])
732
+
733
+ simple_cycles = []
734
+ simple_cycles_number = []
735
+ bridges = list(nx.bridges(G))
736
+ for b in bridges:
737
+ if (d_rev[b[0]], d_rev[b[1]]) in output_edges:
738
+ output_edges.remove((d_rev[b[0]], d_rev[b[1]]))
739
+ es.remove((b[0], b[1]))
740
+ G.remove_edge(b[0], b[1])
741
+ if (d_rev[b[1]], d_rev[b[0]]) in output_edges:
742
+ output_edges.remove((d_rev[b[1]], d_rev[b[0]]))
743
+ es.remove((b[1], b[0]))
744
+ G.remove_edge(b[1], b[0])
745
+ connected_components = list(nx.connected_components(G))
746
+ for c in connected_components:
747
+ if len(c) == 1:
748
+ pass
749
+ else:
750
+ simple_cycles_c = []
751
+ simple_cycles_number_c = []
752
+ output_edges_c = [e for e in output_edges if d[e[0]] in c or d[e[1]] in c]
753
+ output_edges_c_copy_for_traversing = copy.deepcopy(output_edges_c)
754
+
755
+ for edge_c in output_edges_c:
756
+ if edge_c not in output_edges_c_copy_for_traversing:
757
+ pass
758
+ else:
759
+ simple_cycle = []
760
+ simple_cycle_number = []
761
+ point1 = edge_c[0]
762
+ point2 = edge_c[1]
763
+ point1_number = d[point1]
764
+ point2_number = d[point2]
765
+
766
+ if point1_number < point2_number:
767
+ initial_point = point1
768
+ initial_point_number = point1_number
769
+ current_point = point2
770
+ current_point_number = point2_number
771
+ else:
772
+ initial_point = point2
773
+ initial_point_number = point2_number
774
+ current_point = point1
775
+ current_point_number = point1_number
776
+
777
+ simple_cycle.append(initial_point)
778
+ simple_cycle_number.append(initial_point_number)
779
+ simple_cycle.append(current_point)
780
+ simple_cycle_number.append(current_point_number)
781
+
782
+ last_point = initial_point
783
+ next_initial_point = copy.deepcopy(current_point)
784
+ next_point = None
785
+
786
+ while next_point != next_initial_point:
787
+ relevant_edges = []
788
+ for edge in output_edges_c:
789
+ if edge[0] == current_point or edge[1] == current_point:
790
+ relevant_edges.append(edge)
791
+
792
+ relevant_edges_degree = []
793
+ for relevant_edge in relevant_edges:
794
+ vec = None
795
+ if relevant_edge[0] == current_point:
796
+ vec = (
797
+ relevant_edge[1][0] - relevant_edge[0][0],
798
+ relevant_edge[1][1] - relevant_edge[0][1],
799
+ )
800
+ elif relevant_edge[1] == current_point:
801
+ vec = (
802
+ relevant_edge[0][0] - relevant_edge[1][0],
803
+ relevant_edge[0][1] - relevant_edge[1][1],
804
+ )
805
+ else:
806
+ assert 0
807
+ vec_degree = x_axis_angle(vec)
808
+ relevant_edges_degree.append(vec_degree)
809
+
810
+ vec_from_current_point_to_last_point_degree = None
811
+ for relevant_edge_ind, relevant_edge in enumerate(relevant_edges):
812
+ if relevant_edge == (current_point, last_point):
813
+ vec_from_current_point_to_last_point_degree = relevant_edges_degree[relevant_edge_ind]
814
+ relevant_edges.remove(relevant_edge)
815
+ relevant_edges_degree.remove(vec_from_current_point_to_last_point_degree)
816
+ elif relevant_edge == (last_point, current_point):
817
+ vec_from_current_point_to_last_point_degree = relevant_edges_degree[relevant_edge_ind]
818
+ relevant_edges.remove(relevant_edge)
819
+ relevant_edges_degree.remove(vec_from_current_point_to_last_point_degree)
820
+ else:
821
+ continue
822
+
823
+ rotate_deltas_counterclockwise = []
824
+ for relevant_edge_degree in relevant_edges_degree:
825
+ rotate_delta = rotate_degree_counterclockwise_from_counter_degree(
826
+ vec_from_current_point_to_last_point_degree, relevant_edge_degree
827
+ )
828
+ rotate_deltas_counterclockwise.append(rotate_delta)
829
+
830
+ max_rotate_index = rotate_deltas_counterclockwise.index(max(rotate_deltas_counterclockwise))
831
+ max_rotate_edge = relevant_edges[max_rotate_index]
832
+
833
+ if max_rotate_edge[0] == current_point:
834
+ next_point = max_rotate_edge[1]
835
+ next_point_number = d[next_point]
836
+ elif max_rotate_edge[1] == current_point:
837
+ next_point = max_rotate_edge[0]
838
+ next_point_number = d[next_point]
839
+ else:
840
+ assert 0
841
+
842
+ last_point = current_point
843
+ current_point = next_point
844
+ current_point_number = next_point_number
845
+ simple_cycle.append(current_point)
846
+ simple_cycle_number.append(current_point_number)
847
+
848
+ for point_number_ind, point_number in enumerate(simple_cycle_number):
849
+ if point_number_ind < len(simple_cycle_number) - 1:
850
+ edge_number = (point_number, simple_cycle_number[point_number_ind + 1])
851
+ if edge_number[0] < edge_number[1]:
852
+ if (
853
+ d_rev[edge_number[0]],
854
+ d_rev[edge_number[1]],
855
+ ) in output_edges_c_copy_for_traversing:
856
+ output_edges_c_copy_for_traversing.remove(
857
+ (d_rev[edge_number[0]], d_rev[edge_number[1]])
858
+ )
859
+ elif (
860
+ d_rev[edge_number[1]],
861
+ d_rev[edge_number[0]],
862
+ ) in output_edges_c_copy_for_traversing:
863
+ output_edges_c_copy_for_traversing.remove(
864
+ (d_rev[edge_number[1]], d_rev[edge_number[0]])
865
+ )
866
+
867
+ simple_cycle.pop(-1)
868
+ simple_cycle_number.pop(-1)
869
+
870
+ polygon_counterclockwise = [(int(p[0]), -int(p[1])) for p in simple_cycle]
871
+ polygon_counterclockwise.pop(-1)
872
+ if poly_area(polygon_counterclockwise) > 0:
873
+ simple_cycles_c.append(simple_cycle)
874
+ simple_cycles_number_c.append(simple_cycle_number)
875
+
876
+ simple_cycles.extend(simple_cycles_c)
877
+ simple_cycles_number.extend(simple_cycles_number_c)
878
+
879
+ return d_rev, simple_cycles, simple_cycles_number
data_preprocess/raster2graph/util/image_id_dict.py ADDED
The diff for this file is too large to render. See raw diff
 
data_preprocess/raster2graph/util/math_utils.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ def clip(number, _min, _max):
2
+ if number <= _min:
3
+ return _min
4
+ elif number >= _max:
5
+ return _max
6
+ else:
7
+ return number
data_preprocess/raster2graph/util/mean_std.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ mean = [0.920, 0.913, 0.891]
2
+ std = [0.214, 0.216, 0.228]
data_preprocess/raster2graph/util/metric_utils.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+
4
+ from shapely.geometry import Polygon
5
+ from util.geom_utils import poly_iou
6
+
7
+
8
+ def calculate_AP(valid_results, ground_truths, confidence_final):
9
+ ground_truths_copy = copy.deepcopy(ground_truths)
10
+ all_preds = []
11
+ for image_id, image_pred in valid_results.items():
12
+ for i in range(len(image_pred["points"])):
13
+ pred = {}
14
+ pred["score"] = image_pred["scores"][i].item()
15
+ pred["point"] = tuple(image_pred["points"][i].tolist())
16
+ pred["size"] = tuple(image_pred["size"].tolist())
17
+ pred["image_id"] = image_id.item()
18
+ all_preds.append(pred)
19
+ all_preds = sorted(all_preds, key=lambda x: x["score"], reverse=True)
20
+
21
+ all_preds = [pred for pred in all_preds if pred["score"] > confidence_final]
22
+
23
+ all_metrics = []
24
+ for n in range(1, len(all_preds) + 1):
25
+ ground_truths = copy.deepcopy(ground_truths_copy)
26
+
27
+ sub_preds = all_preds[0:n]
28
+
29
+ TP = 0
30
+ FP = 0
31
+ FN = 0
32
+ for pred in sub_preds:
33
+ pred_point = pred["point"]
34
+ img_size = (pred["size"][1], pred["size"][0])
35
+ img_id = pred["image_id"]
36
+ dist_threshold = (img_size[0] * 0.01, img_size[1] * 0.01)
37
+ gt = [tuple(gt_point) for gt_point in ground_truths[img_id]["points"].tolist()]
38
+ gt_copy = copy.deepcopy(gt)
39
+ euc_dists = {}
40
+ dists = {}
41
+ for gt_point in gt_copy:
42
+ if gt_point[2] == 0:
43
+ dist = (abs(pred_point[0] - gt_point[0]), abs(pred_point[1] - gt_point[1]))
44
+ euc_dist = math.sqrt(dist[0] ** 2 + dist[1] ** 2)
45
+ euc_dists[gt_point] = euc_dist
46
+ dists[gt_point] = dist
47
+ euc_dists = sorted(euc_dists.items(), key=lambda x: x[1])
48
+ if len(euc_dists) == 0:
49
+ FP += 1
50
+ continue
51
+ nearest_gt_point = euc_dists[0][0]
52
+ min_dist = dists[nearest_gt_point]
53
+ if min_dist[0] < dist_threshold[0] and min_dist[1] < dist_threshold[1]:
54
+ gtip = ground_truths[img_id]["points"]
55
+ for i, p in enumerate(gtip):
56
+ if (
57
+ p[0].item() == nearest_gt_point[0]
58
+ and p[1].item() == nearest_gt_point[1]
59
+ and p[2].item() == nearest_gt_point[2]
60
+ ):
61
+ # print('qqq', p, nearest_gt_point)
62
+ gtip[i, 2] = 1
63
+ break
64
+ ground_truths[img_id]["points"] = gtip
65
+ # print('rrr', ground_truths[img_id]['points'])
66
+ TP += 1
67
+ continue
68
+ FP += 1
69
+ for img_id, points in ground_truths.items():
70
+ points = points["points"]
71
+ for point in points:
72
+ if point[2] == 0:
73
+ FN += 1
74
+ precision = TP / (TP + FP)
75
+ recall = TP / (TP + FN)
76
+ # print(n, TP, FP, FN, precision, recall)
77
+ all_metrics.append((precision, recall))
78
+
79
+ all_metrics = sorted(all_metrics, key=lambda x: (x[1], x[0]))
80
+ p_r_curve_points = {}
81
+ for point in all_metrics:
82
+ p_r_curve_points[point[1]] = point[0]
83
+ p_r_curve_points[0] = 1
84
+ p_r_curve_points = sorted(p_r_curve_points.items(), key=lambda d: d[0])
85
+ AP = 0
86
+ for i, rp in enumerate(p_r_curve_points):
87
+ r = rp[0]
88
+ p = rp[1]
89
+ if i > 0:
90
+ small_rectangular_area = (r - p_r_curve_points[i - 1][0]) * p
91
+ AP += small_rectangular_area
92
+
93
+ return AP
94
+
95
+
96
+ def get_results(best_result):
97
+ if 1:
98
+ preds = best_result[2]
99
+ output_points = []
100
+ output_edges = []
101
+ for triplet in preds:
102
+ this_preds = triplet[0]
103
+ last_edges = triplet[1]
104
+ this_edges = triplet[2]
105
+ for this_pred in this_preds:
106
+ point = tuple(this_pred["points"].int().tolist())
107
+ output_points.append(point)
108
+ for last_edge in last_edges:
109
+ point1 = tuple(last_edge[0]["points"].int().tolist())
110
+ point2 = tuple(last_edge[1]["points"].int().tolist())
111
+ edge = (point1, point2)
112
+ output_edges.append(edge)
113
+ for this_edge in this_edges:
114
+ point1 = tuple(this_edge[0]["points"].int().tolist())
115
+ point2 = tuple(this_edge[1]["points"].int().tolist())
116
+ edge = (point1, point2)
117
+ output_edges.append(edge)
118
+ return output_points, output_edges
119
+
120
+
121
+ def get_results_visual(best_result):
122
+ if 1:
123
+ preds = best_result[2]
124
+ output_points = []
125
+ output_edges = []
126
+ for layer_index, triplet in enumerate(preds):
127
+ this_preds = triplet[0]
128
+ last_edges = triplet[1]
129
+ this_edges = triplet[2]
130
+ for this_pred in this_preds:
131
+ point = tuple(this_pred["points"].int().tolist())
132
+ output_points.append([layer_index, point])
133
+ for last_edge in last_edges:
134
+ point1 = tuple(last_edge[0]["points"].int().tolist())
135
+ point2 = tuple(last_edge[1]["points"].int().tolist())
136
+ edge = (point1, point2)
137
+ output_edges.append([layer_index, edge])
138
+ for this_edge in this_edges:
139
+ point1 = tuple(this_edge[0]["points"].int().tolist())
140
+ point2 = tuple(this_edge[1]["points"].int().tolist())
141
+ edge = (point1, point2)
142
+ output_edges.append([layer_index, edge])
143
+ return output_points, output_edges, len(preds)
144
+
145
+
146
+ def get_results_float_with_semantic(best_result):
147
+ preds = best_result[2]
148
+ output_points = []
149
+ output_edges = []
150
+ for triplet in preds:
151
+ this_preds = triplet[0]
152
+ last_edges = triplet[1]
153
+ this_edges = triplet[2]
154
+ for this_pred in this_preds:
155
+ point = (
156
+ this_pred["points"].tolist()[0],
157
+ this_pred["points"].tolist()[1],
158
+ this_pred["semantic_left_up"].item(),
159
+ this_pred["semantic_right_up"].item(),
160
+ this_pred["semantic_right_down"].item(),
161
+ this_pred["semantic_left_down"].item(),
162
+ )
163
+ output_points.append(point)
164
+ for last_edge in last_edges:
165
+ point1 = (
166
+ last_edge[0]["points"].tolist()[0],
167
+ last_edge[0]["points"].tolist()[1],
168
+ last_edge[0]["semantic_left_up"].item(),
169
+ last_edge[0]["semantic_right_up"].item(),
170
+ last_edge[0]["semantic_right_down"].item(),
171
+ last_edge[0]["semantic_left_down"].item(),
172
+ )
173
+ point2 = (
174
+ last_edge[1]["points"].tolist()[0],
175
+ last_edge[1]["points"].tolist()[1],
176
+ last_edge[1]["semantic_left_up"].item(),
177
+ last_edge[1]["semantic_right_up"].item(),
178
+ last_edge[1]["semantic_right_down"].item(),
179
+ last_edge[1]["semantic_left_down"].item(),
180
+ )
181
+ edge = (point1, point2)
182
+ output_edges.append(edge)
183
+ for this_edge in this_edges:
184
+ point1 = (
185
+ this_edge[0]["points"].tolist()[0],
186
+ this_edge[0]["points"].tolist()[1],
187
+ this_edge[0]["semantic_left_up"].item(),
188
+ this_edge[0]["semantic_right_up"].item(),
189
+ this_edge[0]["semantic_right_down"].item(),
190
+ this_edge[0]["semantic_left_down"].item(),
191
+ )
192
+ point2 = (
193
+ this_edge[1]["points"].tolist()[0],
194
+ this_edge[1]["points"].tolist()[1],
195
+ this_edge[1]["semantic_left_up"].item(),
196
+ this_edge[1]["semantic_right_up"].item(),
197
+ this_edge[1]["semantic_right_down"].item(),
198
+ this_edge[1]["semantic_left_down"].item(),
199
+ )
200
+ edge = (point1, point2)
201
+ output_edges.append(edge)
202
+
203
+ return output_points, output_edges
204
+
205
+
206
+ def calculate_single_sample(
207
+ best_result, graph, target_d_rev, target_simple_cycles, target_results, d_rev, simple_cycles, results
208
+ ):
209
+ output_points, output_edges = get_results(best_result)
210
+ gt_points = [k for k, v in graph.items()]
211
+ gt_edges = []
212
+ for k, v in graph.items():
213
+ for adj in v:
214
+ if adj != (-1, -1):
215
+ gt_edge = (k, adj)
216
+ if (adj, k) not in gt_edges:
217
+ gt_edges.append(gt_edge)
218
+
219
+ points_TP = 0
220
+ points_FP = 0
221
+ points_FN = 0
222
+ dist_error_x = 0
223
+ dist_error_y = 0
224
+ dist_error_l2 = 0
225
+ gt_points_copy = copy.deepcopy(gt_points)
226
+ threshold = 5
227
+ for output_point in output_points:
228
+ matched = False
229
+ for gt_point in gt_points:
230
+ if (abs(output_point[0] - gt_point[0]) <= threshold) and (abs(output_point[1] - gt_point[1]) <= threshold):
231
+ if gt_point in gt_points_copy:
232
+ points_TP += 1
233
+ dist_error_x += abs(output_point[0] - gt_point[0])
234
+ dist_error_y += abs(output_point[1] - gt_point[1])
235
+ dist_error_l2 += (
236
+ abs(output_point[0] - gt_point[0]) ** 2 + abs(output_point[1] - gt_point[1]) ** 2
237
+ ) ** 0.5
238
+ matched = True
239
+ gt_points_copy.remove(gt_point)
240
+ break
241
+ if not matched:
242
+ points_FP += 1
243
+ points_FN = len(gt_points) - points_TP
244
+
245
+ edges_TP = 0
246
+ edges_FP = 0
247
+ edges_FN = 0
248
+ gt_edges_copy = copy.deepcopy(gt_edges)
249
+ threshold = 5
250
+ for output_edge in output_edges:
251
+ matched = False
252
+ for gt_edge in gt_edges:
253
+ if (
254
+ (
255
+ (abs(output_edge[0][0] - gt_edge[0][0]) <= threshold)
256
+ and (abs(output_edge[0][1] - gt_edge[0][1]) <= threshold)
257
+ )
258
+ and (
259
+ (abs(output_edge[1][0] - gt_edge[1][0]) <= threshold)
260
+ and (abs(output_edge[1][1] - gt_edge[1][1]) <= threshold)
261
+ )
262
+ ) or (
263
+ (
264
+ (abs(output_edge[0][0] - gt_edge[1][0]) <= threshold)
265
+ and (abs(output_edge[0][1] - gt_edge[1][1]) <= threshold)
266
+ )
267
+ and (
268
+ (abs(output_edge[1][0] - gt_edge[0][0]) <= threshold)
269
+ and (abs(output_edge[1][1] - gt_edge[0][1]) <= threshold)
270
+ )
271
+ ):
272
+ if gt_edge in gt_edges_copy:
273
+ edges_TP += 1
274
+ matched = True
275
+ gt_edges_copy.remove(gt_edge)
276
+ break
277
+ if not matched:
278
+ edges_FP += 1
279
+ edges_FN = len(gt_edges) - edges_TP
280
+
281
+ regions_TP = 0
282
+ regions_FP = 0
283
+ regions_FN = 0
284
+ rooms_TP = 0
285
+ rooms_FP = 0
286
+ rooms_FN = 0
287
+ gt_regions = []
288
+ output_regions = []
289
+
290
+ for target_simple_cycle in target_simple_cycles:
291
+ target_polyg = [(point_i[0], point_i[1]) for point_i in target_simple_cycle]
292
+ gt_regions.append(target_polyg)
293
+
294
+ for simple_cycle in simple_cycles:
295
+ polyg = [(point_i[0], point_i[1]) for point_i in simple_cycle]
296
+ polyg.pop(-1)
297
+ output_regions.append(polyg)
298
+ gt_regions_copy = copy.deepcopy(gt_regions)
299
+ iou_threshold = 0.7
300
+ for output_region_i, output_region in enumerate(output_regions):
301
+ matched = False
302
+ for gt_region_i, gt_region in enumerate(gt_regions):
303
+ if poly_iou(Polygon(gt_region), Polygon(output_region)) >= iou_threshold:
304
+ if gt_region in gt_regions_copy:
305
+ regions_TP += 1
306
+ if target_results[gt_region_i] == results[output_region_i]:
307
+ rooms_TP += 1
308
+ else:
309
+ rooms_FP += 1
310
+ matched = True
311
+ gt_regions_copy.remove(gt_region)
312
+ break
313
+ if not matched:
314
+ regions_FP += 1
315
+ rooms_FP += 1
316
+ regions_FN = len(gt_regions) - regions_TP
317
+ rooms_FN = len(gt_regions) - rooms_TP
318
+ # print(regions_TP, regions_FP, regions_FN)
319
+ # print(rooms_TP, rooms_FP, rooms_FN)
320
+
321
+ dist_error = (0, 0, 0)
322
+ if points_TP > 0:
323
+ dist_error = (dist_error_x, dist_error_y, dist_error_l2)
324
+ return (
325
+ points_TP,
326
+ points_FP,
327
+ points_FN,
328
+ edges_TP,
329
+ edges_FP,
330
+ edges_FN,
331
+ dist_error,
332
+ regions_TP,
333
+ regions_FP,
334
+ regions_FN,
335
+ rooms_TP,
336
+ rooms_FP,
337
+ rooms_FN,
338
+ )
data_preprocess/raster2graph/util/semantics_dict.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ semantics_dict = {
2
+ "living_room": 1,
3
+ "kitchen": 2,
4
+ "bedroom": 3,
5
+ "bathroom": 4,
6
+ "restroom": 5,
7
+ "balcony": 6,
8
+ "closet": 7,
9
+ "corridor": 8,
10
+ "washing_room": 9,
11
+ "PS": 10,
12
+ "outside": 11,
13
+ "wall": 12,
14
+ "no_type": 0,
15
+ }
16
+ semantics_dict_rev = {
17
+ 0: "no_type",
18
+ 1: "living_room",
19
+ 2: "kitchen",
20
+ 3: "bedroom",
21
+ 4: "bathroom",
22
+ 5: "restroom",
23
+ 6: "balcony",
24
+ 7: "closet",
25
+ 8: "corridor",
26
+ 9: "washing_room",
27
+ 10: "PS",
28
+ 11: "outside",
29
+ 12: "wall",
30
+ }
31
+ semantics_dict_color = {
32
+ "living_room": (0, 0, 220),
33
+ "kitchen": (0, 220, 220),
34
+ "bedroom": (0, 220, 0),
35
+ "bathroom": (220, 220, 0),
36
+ "restroom": (220, 0, 0),
37
+ "balcony": (220, 0, 220),
38
+ "closet": (110, 0, 110),
39
+ "corridor": (110, 0, 0),
40
+ "washing_room": (0, 0, 110),
41
+ "PS": (0, 110, 110),
42
+ "outside": (0, 0, 0),
43
+ "wall": (110, 110, 110),
44
+ "no_type": (20, 20, 20),
45
+ }
data_preprocess/stru3d/PointCloudReaderPanorama.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import open3d as o3d
6
+
7
+ NUM_SECTIONS = -1
8
+
9
+
10
+ class PointCloudReaderPanorama:
11
+ def __init__(self, path, resolution="full", random_level=0, generate_color=False, generate_normal=False):
12
+ self.path = path
13
+ self.random_level = random_level
14
+ self.resolution = resolution
15
+ self.generate_color = generate_color
16
+ self.generate_normal = generate_normal
17
+ sections = [p for p in os.listdir(os.path.join(path, "2D_rendering"))]
18
+ self.depth_paths = [
19
+ os.path.join(*[path, "2D_rendering", p, "panorama", self.resolution, "depth.png"]) for p in sections
20
+ ]
21
+ self.rgb_paths = [
22
+ os.path.join(*[path, "2D_rendering", p, "panorama", self.resolution, "rgb_coldlight.png"])
23
+ for p in sections
24
+ ]
25
+ self.normal_paths = [
26
+ os.path.join(*[path, "2D_rendering", p, "panorama", self.resolution, "normal.png"]) for p in sections
27
+ ]
28
+ self.camera_paths = [os.path.join(*[path, "2D_rendering", p, "panorama", "camera_xyz.txt"]) for p in sections]
29
+ self.camera_centers = self.read_camera_center()
30
+ self.point_cloud = self.generate_point_cloud(
31
+ self.random_level, color=self.generate_color, normal=self.generate_normal
32
+ )
33
+
34
+ def read_camera_center(self):
35
+ camera_centers = []
36
+ for i in range(len(self.camera_paths)):
37
+ with open(self.camera_paths[i], "r") as f:
38
+ line = f.readline()
39
+ center = list(map(float, line.strip().split(" ")))
40
+ camera_centers.append(np.asarray([center[0], center[1], center[2]]))
41
+ return camera_centers
42
+
43
+ def generate_point_cloud(self, random_level=0, color=False, normal=False):
44
+ coords = []
45
+ colors = []
46
+ points = {}
47
+ # normals = []
48
+
49
+ # Getting Coordinates
50
+ for i in range(len(self.depth_paths)):
51
+ depth_img = cv2.imread(self.depth_paths[i], cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)
52
+ x_tick = 180.0 / depth_img.shape[0]
53
+ y_tick = 360.0 / depth_img.shape[1]
54
+
55
+ rgb_img = cv2.imread(self.rgb_paths[i])
56
+ rgb_img = cv2.cvtColor(rgb_img, code=cv2.COLOR_BGR2RGB)
57
+ # normal_img = cv2.imread(self.normal_paths[i])
58
+
59
+ for x in range(0, depth_img.shape[0]):
60
+ for y in range(0, depth_img.shape[1]):
61
+ # need 90 - -09
62
+ alpha = 90 - (x * x_tick)
63
+ beta = y * y_tick - 180
64
+
65
+ depth = depth_img[x, y] + np.random.random() * random_level
66
+
67
+ if depth > 500.0:
68
+ z_offset = depth * np.sin(np.deg2rad(alpha))
69
+ xy_offset = depth * np.cos(np.deg2rad(alpha))
70
+ x_offset = xy_offset * np.sin(np.deg2rad(beta))
71
+ y_offset = xy_offset * np.cos(np.deg2rad(beta))
72
+ point = np.asarray([x_offset, y_offset, z_offset])
73
+ coords.append(point + self.camera_centers[i])
74
+ colors.append(rgb_img[x, y])
75
+ # normals.append(normalize(normal_img[x, y].reshape(-1, 1)).ravel())
76
+
77
+ coords = np.asarray(coords)
78
+ colors = np.asarray(colors) / 255.0
79
+ # normals = np.asarray(normals)
80
+
81
+ coords[:, :2] = np.round(coords[:, :2] / 10) * 10.0
82
+ coords[:, 2] = np.round(coords[:, 2] / 100) * 100.0
83
+ unique_coords, unique_ind = np.unique(coords, return_index=True, axis=0)
84
+
85
+ coords = coords[unique_ind]
86
+ colors = colors[unique_ind]
87
+ # normals = normals[unique_ind]
88
+
89
+ points["coords"] = coords
90
+ points["colors"] = colors
91
+ # points['normals'] = normals
92
+
93
+ print("Pointcloud size:", points["coords"].shape[0])
94
+ return points
95
+
96
+ def get_point_cloud(self):
97
+ return self.point_cloud
98
+
99
+ def generate_density(self, width=256, height=256):
100
+
101
+ ps = self.point_cloud["coords"] * -1
102
+ ps[:, 0] *= -1
103
+ ps[:, 1] *= -1
104
+
105
+ pcd = o3d.geometry.PointCloud()
106
+ pcd.points = o3d.utility.Vector3dVector(ps)
107
+ pcd.estimate_normals()
108
+
109
+ # zs = np.round(ps[:,2] / 100) * 100
110
+ # zs, zs_ind = np.unique(zs, return_index=True, axis=0)
111
+ # ps_ind = ps[:, :2] ==
112
+ # print("Generate density...")
113
+
114
+ image_res = np.array((width, height))
115
+
116
+ max_coords = np.max(ps, axis=0)
117
+ min_coords = np.min(ps, axis=0)
118
+ max_m_min = max_coords - min_coords
119
+
120
+ max_coords = max_coords + 0.1 * max_m_min
121
+ min_coords = min_coords - 0.1 * max_m_min
122
+
123
+ normalization_dict = {}
124
+ normalization_dict["min_coords"] = min_coords
125
+ normalization_dict["max_coords"] = max_coords
126
+ normalization_dict["image_res"] = image_res
127
+
128
+ # coordinates = np.round(points[:, :2] / max_coordinates[None,:2] * image_res[None])
129
+ coordinates = np.round(
130
+ (ps[:, :2] - min_coords[None, :2]) / (max_coords[None, :2] - min_coords[None, :2]) * image_res[None]
131
+ )
132
+ coordinates = np.minimum(np.maximum(coordinates, np.zeros_like(image_res)), image_res - 1)
133
+
134
+ density = np.zeros((height, width), dtype=np.float32)
135
+
136
+ unique_coordinates, counts = np.unique(coordinates, return_counts=True, axis=0)
137
+ # print(np.unique(counts))
138
+ # counts = np.minimum(counts, 1e2)
139
+
140
+ unique_coordinates = unique_coordinates.astype(np.int32)
141
+
142
+ density[unique_coordinates[:, 1], unique_coordinates[:, 0]] = counts
143
+ density = density / np.max(density)
144
+ # print(np.unique(density))
145
+
146
+ normals = np.array(pcd.normals)
147
+ normals_map = np.zeros((density.shape[0], density.shape[1], 3))
148
+
149
+ import time
150
+
151
+ start_time = time.time()
152
+ for i, unique_coord in enumerate(unique_coordinates):
153
+ # print(normals[unique_ind])
154
+ normals_indcs = np.argwhere(np.all(coordinates[::10] == unique_coord, axis=1))[:, 0]
155
+ normals_map[unique_coordinates[i, 1], unique_coordinates[i, 0], :] = np.mean(
156
+ normals[::10][normals_indcs, :], axis=0
157
+ )
158
+
159
+ print("Time for normals: ", time.time() - start_time)
160
+
161
+ normals_map = (np.clip(normals_map, 0, 1) * 255).astype(np.uint8)
162
+
163
+ # plt.figure()
164
+ # plt.imshow(normals_map)
165
+ # plt.show()
166
+
167
+ return density, normals_map, normalization_dict
168
+
169
+ def visualize(self, export_path=None):
170
+ pcd = o3d.geometry.PointCloud()
171
+
172
+ points = self.point_cloud["coords"]
173
+
174
+ print(np.max(points, axis=0))
175
+ indices = np.where(points[:, 2] < 2000)
176
+
177
+ points = points[indices]
178
+ points[:, 1] *= -1
179
+ points[:, :] /= 1000
180
+ pcd.points = o3d.utility.Vector3dVector(points)
181
+
182
+ if self.generate_normal:
183
+ normals = self.point_cloud["normals"]
184
+ normals = normals[indices]
185
+ pcd.normals = o3d.utility.Vector3dVector(normals)
186
+ if self.generate_color:
187
+ colors = self.point_cloud["colors"]
188
+ colors = colors[indices]
189
+ pcd.colors = o3d.utility.Vector3dVector(colors)
190
+
191
+ # wireframe_geo_list = visualize_wireframe(annos, vis=False, ret=True)
192
+ # o3d.visualization.draw_geometries([pcd] + wireframe_geo_list)
193
+ # o3d.visualization.draw_geometries([pcd])
194
+
195
+ pcd.estimate_normals()
196
+
197
+ # radii = 0.01
198
+ # mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, radii)
199
+
200
+ # alpha = 0.1
201
+ # tetra_mesh, pt_map = o3d.geometry.TetraMesh.create_from_point_cloud(pcd)
202
+ # mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape(pcd, alpha, tetra_mesh, pt_map)
203
+
204
+ o3d.visualization.draw_geometries([pcd])
205
+
206
+ if export_path is not None:
207
+ o3d.io.write_point_cloud(export_path, pcd)
208
+
209
+ # o3d.visualization.draw_geometries([pcd])
210
+
211
+ def export_ply(self, path):
212
+ """
213
+ ply
214
+ format ascii 1.0
215
+ comment Mars model by Paul Bourke
216
+ element vertex 259200
217
+ property float x
218
+ property float y
219
+ property float z
220
+ property uchar r
221
+ property uchar g
222
+ property uchar b
223
+ property float nx
224
+ property float ny
225
+ property float nz
226
+ end_header
227
+ """
228
+ with open(path, "w") as f:
229
+ f.write("ply\n")
230
+ f.write("format ascii 1.0\n")
231
+ f.write("element vertex %d\n" % self.point_cloud["coords"].shape[0])
232
+ f.write("property float x\n")
233
+ f.write("property float y\n")
234
+ f.write("property float z\n")
235
+ if self.generate_color:
236
+ f.write("property uchar red\n")
237
+ f.write("property uchar green\n")
238
+ f.write("property uchar blue\n")
239
+ if self.generate_normal:
240
+ f.write("property float nx\n")
241
+ f.write("property float ny\n")
242
+ f.write("property float nz\n")
243
+ f.write("end_header\n")
244
+ for i in range(self.point_cloud["coords"].shape[0]):
245
+ normal = []
246
+ color = []
247
+ coord = self.point_cloud["coords"][i].tolist()
248
+ if self.generate_color:
249
+ color = list(map(int, (self.point_cloud["colors"][i] * 255).tolist()))
250
+ if self.generate_normal:
251
+ normal = self.point_cloud["normals"][i].tolist()
252
+ data = coord + color + normal
253
+ f.write(" ".join(list(map(str, data))) + "\n")
data_preprocess/stru3d/generate_coco_stru3d.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+
6
+ from stru3d_utils import generate_coco_dict, generate_density, normalize_annotations, parse_floor_plan_polys
7
+ from tqdm import tqdm
8
+
9
+ sys.path.append("../.")
10
+ from common_utils import export_density, read_scene_pc
11
+
12
+ ### Note: Some scenes have missing/wrong annotations. These are the indices that you should additionally exclude
13
+ ### to be consistent with MonteFloor and HEAT:
14
+ invalid_scenes_ids = [
15
+ 76,
16
+ 183,
17
+ 335,
18
+ 491,
19
+ 663,
20
+ 681,
21
+ 703,
22
+ 728,
23
+ 865,
24
+ 936,
25
+ 985,
26
+ 986,
27
+ 1009,
28
+ 1104,
29
+ 1155,
30
+ 1221,
31
+ 1282,
32
+ 1365,
33
+ 1378,
34
+ 1635,
35
+ 1745,
36
+ 1772,
37
+ 1774,
38
+ 1816,
39
+ 1866,
40
+ 2037,
41
+ 2076,
42
+ 2274,
43
+ 2334,
44
+ 2357,
45
+ 2580,
46
+ 2665,
47
+ 2706,
48
+ 2713,
49
+ 2771,
50
+ 2868,
51
+ 3156,
52
+ 3192,
53
+ 3198,
54
+ 3261,
55
+ 3271,
56
+ 3276,
57
+ 3296,
58
+ 3342,
59
+ 3387,
60
+ 3398,
61
+ 3466,
62
+ 3496,
63
+ ]
64
+
65
+ type2id = {
66
+ "living room": 0,
67
+ "kitchen": 1,
68
+ "bedroom": 2,
69
+ "bathroom": 3,
70
+ "balcony": 4,
71
+ "corridor": 5,
72
+ "dining room": 6,
73
+ "study": 7,
74
+ "studio": 8,
75
+ "store room": 9,
76
+ "garden": 10,
77
+ "laundry room": 11,
78
+ "office": 12,
79
+ "basement": 13,
80
+ "garage": 14,
81
+ "undefined": 15,
82
+ "door": 16,
83
+ "window": 17,
84
+ }
85
+
86
+
87
+ def config():
88
+ a = argparse.ArgumentParser(description="Generate coco format data for Structured3D")
89
+ a.add_argument(
90
+ "--data_root", default="Structured3D_panorama", type=str, help="path to raw Structured3D_panorama folder"
91
+ )
92
+ a.add_argument("--output", default="coco_stru3d", type=str, help="path to output folder")
93
+
94
+ args = a.parse_args()
95
+ return args
96
+
97
+
98
+ def main(args):
99
+ data_root = args.data_root
100
+ data_parts = os.listdir(data_root)
101
+
102
+ ### prepare
103
+ outFolder = args.output
104
+ if not os.path.exists(outFolder):
105
+ os.mkdir(outFolder)
106
+
107
+ annotation_outFolder = os.path.join(outFolder, "annotations")
108
+ if not os.path.exists(annotation_outFolder):
109
+ os.mkdir(annotation_outFolder)
110
+
111
+ train_img_folder = os.path.join(outFolder, "train")
112
+ val_img_folder = os.path.join(outFolder, "val")
113
+ test_img_folder = os.path.join(outFolder, "test")
114
+
115
+ for img_folder in [train_img_folder, val_img_folder, test_img_folder]:
116
+ if not os.path.exists(img_folder):
117
+ os.mkdir(img_folder)
118
+
119
+ coco_train_json_path = os.path.join(annotation_outFolder, "train.json")
120
+ coco_val_json_path = os.path.join(annotation_outFolder, "val.json")
121
+ coco_test_json_path = os.path.join(annotation_outFolder, "test.json")
122
+
123
+ coco_train_dict = {"images": [], "annotations": [], "categories": []}
124
+ coco_val_dict = {"images": [], "annotations": [], "categories": []}
125
+ coco_test_dict = {"images": [], "annotations": [], "categories": []}
126
+
127
+ for key, value in type2id.items():
128
+ type_dict = {"supercategory": "room", "id": value, "name": key}
129
+ coco_train_dict["categories"].append(type_dict)
130
+ coco_val_dict["categories"].append(type_dict)
131
+ coco_test_dict["categories"].append(type_dict)
132
+
133
+ ### begin processing
134
+ instance_id = 0
135
+ for part in tqdm(data_parts):
136
+ scenes = os.listdir(os.path.join(data_root, part, "Structured3D"))
137
+ for scene in tqdm(scenes):
138
+ scene_path = os.path.join(data_root, part, "Structured3D", scene)
139
+ scene_id = scene.split("_")[-1]
140
+
141
+ if int(scene_id) in invalid_scenes_ids:
142
+ print("skip {}".format(scene))
143
+ continue
144
+
145
+ # load pre-generated point cloud
146
+ ply_path = os.path.join(scene_path, "point_cloud.ply")
147
+ points = read_scene_pc(ply_path)
148
+ xyz = points[:, :3]
149
+
150
+ ### project point cloud to density map
151
+ density, normalization_dict = generate_density(xyz, width=256, height=256)
152
+
153
+ ### rescale raw annotations
154
+ normalized_annos = normalize_annotations(scene_path, normalization_dict)
155
+
156
+ ### prepare coco dict
157
+ img_id = int(scene_id)
158
+ img_dict = {}
159
+ img_dict["file_name"] = scene_id + ".png"
160
+ img_dict["id"] = img_id
161
+ img_dict["width"] = 256
162
+ img_dict["height"] = 256
163
+
164
+ ### parse annotations
165
+ polys = parse_floor_plan_polys(normalized_annos)
166
+ polygons_list = generate_coco_dict(normalized_annos, polys, instance_id, img_id, ignore_types=["outwall"])
167
+
168
+ instance_id += len(polygons_list)
169
+
170
+ ### train
171
+ if int(scene_id) < 3000:
172
+ coco_train_dict["images"].append(img_dict)
173
+ coco_train_dict["annotations"] += polygons_list
174
+ export_density(density, train_img_folder, scene_id)
175
+
176
+ ### val
177
+ elif int(scene_id) >= 3000 and int(scene_id) < 3250:
178
+ coco_val_dict["images"].append(img_dict)
179
+ coco_val_dict["annotations"] += polygons_list
180
+ export_density(density, val_img_folder, scene_id)
181
+
182
+ ### test
183
+ else:
184
+ coco_test_dict["images"].append(img_dict)
185
+ coco_test_dict["annotations"] += polygons_list
186
+ export_density(density, test_img_folder, scene_id)
187
+
188
+ print(scene_id)
189
+
190
+ with open(coco_train_json_path, "w") as f:
191
+ json.dump(coco_train_dict, f)
192
+ with open(coco_val_json_path, "w") as f:
193
+ json.dump(coco_val_dict, f)
194
+ with open(coco_test_json_path, "w") as f:
195
+ json.dump(coco_test_dict, f)
196
+
197
+
198
+ if __name__ == "__main__":
199
+ main(config())
data_preprocess/stru3d/generate_point_cloud_stru3d.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ from PointCloudReaderPanorama import PointCloudReaderPanorama
5
+ from tqdm import tqdm
6
+
7
+
8
+ def config():
9
+ a = argparse.ArgumentParser(description="Generate point cloud for Structured3D")
10
+ a.add_argument(
11
+ "--data_root", default="Structured3D_panorama", type=str, help="path to raw Structured3D_panorama folder"
12
+ )
13
+ args = a.parse_args()
14
+ return args
15
+
16
+
17
+ def main(args):
18
+ print("Creating point cloud from perspective views...")
19
+ data_root = args.data_root
20
+ data_parts = os.listdir(data_root)
21
+
22
+ for part in tqdm(data_parts):
23
+ scenes = os.listdir(os.path.join(data_root, part, "Structured3D"))
24
+ for scene in tqdm(scenes):
25
+ scene_path = os.path.join(data_root, part, "Structured3D", scene)
26
+ reader = PointCloudReaderPanorama(scene_path, random_level=0, generate_color=True, generate_normal=False)
27
+ save_path = os.path.join(data_root, part, "Structured3D", scene, "point_cloud.ply")
28
+ reader.export_ply(save_path)
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main(config())
data_preprocess/stru3d/stru3d_utils.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This code is an adaptation that uses Structured 3D for the code base.
3
+
4
+ Reference: https://github.com/bertjiazheng/Structured3D
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import sys
10
+
11
+ import numpy as np
12
+ from shapely.geometry import Polygon
13
+
14
+ sys.path.append("../data_preprocess")
15
+ from common_utils import resort_corners
16
+
17
+ type2id = {
18
+ "living room": 0,
19
+ "kitchen": 1,
20
+ "bedroom": 2,
21
+ "bathroom": 3,
22
+ "balcony": 4,
23
+ "corridor": 5,
24
+ "dining room": 6,
25
+ "study": 7,
26
+ "studio": 8,
27
+ "store room": 9,
28
+ "garden": 10,
29
+ "laundry room": 11,
30
+ "office": 12,
31
+ "basement": 13,
32
+ "garage": 14,
33
+ "undefined": 15,
34
+ "door": 16,
35
+ "window": 17,
36
+ }
37
+
38
+
39
+ def generate_density(point_cloud, width=256, height=256):
40
+
41
+ ps = point_cloud * -1
42
+ ps[:, 0] *= -1
43
+ ps[:, 1] *= -1
44
+
45
+ image_res = np.array((width, height))
46
+
47
+ max_coords = np.max(ps, axis=0)
48
+ min_coords = np.min(ps, axis=0)
49
+ max_m_min = max_coords - min_coords
50
+
51
+ max_coords = max_coords + 0.1 * max_m_min
52
+ min_coords = min_coords - 0.1 * max_m_min
53
+
54
+ normalization_dict = {}
55
+ normalization_dict["min_coords"] = min_coords
56
+ normalization_dict["max_coords"] = max_coords
57
+ normalization_dict["image_res"] = image_res
58
+
59
+ # coordinates = np.round(points[:, :2] / max_coordinates[None,:2] * image_res[None])
60
+ coordinates = np.round(
61
+ (ps[:, :2] - min_coords[None, :2]) / (max_coords[None, :2] - min_coords[None, :2]) * image_res[None]
62
+ )
63
+ coordinates = np.minimum(np.maximum(coordinates, np.zeros_like(image_res)), image_res - 1)
64
+
65
+ density = np.zeros((height, width), dtype=np.float32)
66
+
67
+ unique_coordinates, counts = np.unique(coordinates, return_counts=True, axis=0)
68
+ # print(np.unique(counts))
69
+ # counts = np.minimum(counts, 1e2)
70
+
71
+ unique_coordinates = unique_coordinates.astype(np.int32)
72
+
73
+ density[unique_coordinates[:, 1], unique_coordinates[:, 0]] = counts
74
+ density = density / np.max(density)
75
+
76
+ return density, normalization_dict
77
+
78
+
79
+ def normalize_point(point, normalization_dict):
80
+
81
+ min_coords = normalization_dict["min_coords"]
82
+ max_coords = normalization_dict["max_coords"]
83
+ image_res = normalization_dict["image_res"]
84
+
85
+ point_2d = np.round((point[:2] - min_coords[:2]) / (max_coords[:2] - min_coords[:2]) * image_res)
86
+ point_2d = np.minimum(np.maximum(point_2d, np.zeros_like(image_res)), image_res - 1)
87
+
88
+ point[:2] = point_2d.tolist()
89
+
90
+ return point
91
+
92
+
93
+ def normalize_annotations(scene_path, normalization_dict):
94
+ annotation_path = os.path.join(scene_path, "annotation_3d.json")
95
+ with open(annotation_path, "r") as f:
96
+ annotation_json = json.load(f)
97
+
98
+ for line in annotation_json["lines"]:
99
+ point = line["point"]
100
+ point = normalize_point(point, normalization_dict)
101
+ line["point"] = point
102
+
103
+ for junction in annotation_json["junctions"]:
104
+ point = junction["coordinate"]
105
+ point = normalize_point(point, normalization_dict)
106
+ junction["coordinate"] = point
107
+
108
+ return annotation_json
109
+
110
+
111
+ def parse_floor_plan_polys(annos):
112
+ planes = []
113
+ for semantic in annos["semantics"]:
114
+ for planeID in semantic["planeID"]:
115
+ if annos["planes"][planeID]["type"] == "floor":
116
+ planes.append({"planeID": planeID, "type": semantic["type"]})
117
+
118
+ # if semantic["type"] == "outwall":
119
+ # outerwall_planes = semantic["planeID"]
120
+
121
+ # extract hole vertices
122
+ lines_holes = []
123
+ for semantic in annos["semantics"]:
124
+ if semantic["type"] in ["window", "door"]:
125
+ for planeID in semantic["planeID"]:
126
+ lines_holes.extend(np.where(np.array(annos["planeLineMatrix"][planeID]))[0].tolist())
127
+ lines_holes = np.unique(lines_holes)
128
+
129
+ ## junctions on the floor
130
+ # junctions = np.array([junc["coordinate"] for junc in annos["junctions"]])
131
+
132
+ # construct each polygon
133
+ polygons = []
134
+ for plane in planes:
135
+ lineIDs = np.where(np.array(annos["planeLineMatrix"][plane["planeID"]]))[0].tolist()
136
+ junction_pairs = [np.where(np.array(annos["lineJunctionMatrix"][lineID]))[0].tolist() for lineID in lineIDs]
137
+ polygon = convert_lines_to_vertices(junction_pairs)
138
+ polygons.append([polygon[0], plane["type"]])
139
+
140
+ return polygons
141
+
142
+
143
+ def convert_lines_to_vertices(lines):
144
+ """
145
+ convert line representation to polygon vertices
146
+
147
+ """
148
+ polygons = []
149
+ lines = np.array(lines)
150
+
151
+ polygon = None
152
+ while len(lines) != 0:
153
+ if polygon is None:
154
+ polygon = lines[0].tolist()
155
+ lines = np.delete(lines, 0, 0)
156
+
157
+ lineID, juncID = np.where(lines == polygon[-1])
158
+ vertex = lines[lineID[0], 1 - juncID[0]]
159
+ lines = np.delete(lines, lineID, 0)
160
+
161
+ if vertex in polygon:
162
+ polygons.append(polygon)
163
+ polygon = None
164
+ else:
165
+ polygon.append(vertex)
166
+
167
+ return polygons
168
+
169
+
170
+ def generate_coco_dict(annos, polygons, curr_instance_id, curr_img_id, ignore_types):
171
+
172
+ junctions = np.array([junc["coordinate"][:2] for junc in annos["junctions"]])
173
+
174
+ coco_annotation_dict_list = []
175
+
176
+ for poly_ind, (polygon, poly_type) in enumerate(polygons):
177
+ if poly_type in ignore_types:
178
+ continue
179
+
180
+ polygon = junctions[np.array(polygon)]
181
+
182
+ poly_shapely = Polygon(polygon)
183
+ area = poly_shapely.area
184
+
185
+ # assert area > 10
186
+ # if area < 100:
187
+ if poly_type not in ["door", "window"] and area < 100:
188
+ continue
189
+ if poly_type in ["door", "window"] and area < 1:
190
+ continue
191
+
192
+ rectangle_shapely = poly_shapely.envelope
193
+
194
+ ### here we convert door/window annotation into a single line
195
+ if poly_type in ["door", "window"]:
196
+ assert polygon.shape[0] == 4
197
+ midp_1 = (polygon[0] + polygon[1]) / 2
198
+ midp_2 = (polygon[1] + polygon[2]) / 2
199
+ midp_3 = (polygon[2] + polygon[3]) / 2
200
+ midp_4 = (polygon[3] + polygon[0]) / 2
201
+
202
+ dist_1_3 = np.square(midp_1 - midp_3).sum()
203
+ dist_2_4 = np.square(midp_2 - midp_4).sum()
204
+ if dist_1_3 > dist_2_4:
205
+ polygon = np.row_stack([midp_1, midp_3])
206
+ else:
207
+ polygon = np.row_stack([midp_2, midp_4])
208
+
209
+ coco_seg_poly = []
210
+ poly_sorted = resort_corners(polygon)
211
+
212
+ for p in poly_sorted:
213
+ coco_seg_poly += list(p)
214
+
215
+ # Slightly wider bounding box
216
+ bound_pad = 2
217
+ bb_x, bb_y = rectangle_shapely.exterior.xy
218
+ bb_x = np.unique(bb_x)
219
+ bb_y = np.unique(bb_y)
220
+ bb_x_min = np.maximum(np.min(bb_x) - bound_pad, 0)
221
+ bb_y_min = np.maximum(np.min(bb_y) - bound_pad, 0)
222
+
223
+ bb_x_max = np.minimum(np.max(bb_x) + bound_pad, 256 - 1)
224
+ bb_y_max = np.minimum(np.max(bb_y) + bound_pad, 256 - 1)
225
+
226
+ bb_width = bb_x_max - bb_x_min
227
+ bb_height = bb_y_max - bb_y_min
228
+
229
+ coco_bb = [bb_x_min, bb_y_min, bb_width, bb_height]
230
+
231
+ coco_annotation_dict = {
232
+ "segmentation": [coco_seg_poly],
233
+ "area": area,
234
+ "iscrowd": 0,
235
+ "image_id": curr_img_id,
236
+ "bbox": coco_bb,
237
+ "category_id": type2id[poly_type],
238
+ "id": curr_instance_id,
239
+ }
240
+
241
+ coco_annotation_dict_list.append(coco_annotation_dict)
242
+ curr_instance_id += 1
243
+
244
+ return coco_annotation_dict_list
data_preprocess/tools/plot_data.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # Additional useful arguments:
4
+ # --crop_white_space: remove redundant whitespace from the rendering
5
+ # --one_color: use single color for every room (i.e. yellow)
6
+ # --compute_stats: compute statistics of the dataset (e.g. max_num_pts, max_num_polys)
7
+ # and plot histogram for counting number of Points, Rooms, Corners
8
+ # --drop_wd: disable Windor & Door in the plots
9
+ # --image_scale: adjust rendering resolution of the plots
10
+
11
+ SPLIT=test
12
+ python plot_floor.py --dataset_name=stru3d \
13
+ --dataset_root=data/coco_s3d_bw/ \
14
+ --eval_set=${SPLIT} \
15
+ --output_dir=data_plots/output_gt_s3dbw/${SPLIT} \
16
+ --semantic_classes=19 \
17
+ --input_channels 3 \
18
+ --disable_image_transform \
19
+ --poly2seq \
20
+ --image_size 256 \
21
+ --image_scale 1 \
22
+ --compute_stats \
23
+ --plot_gt \
24
+ --plot_gt_image \
25
+ --plot_polys \
26
+ --plot_density
27
+
28
+
29
+ SPLIT=test
30
+ python plot_floor.py --dataset_name=r2g \
31
+ --dataset_root=data/R2G_hr_dataset_processed_v1/ \
32
+ --eval_set=${SPLIT} \
33
+ --output_dir=output_gt_r2g/${SPLIT} \
34
+ --semantic_classes=13 \
35
+ --input_channels 3 \
36
+ --poly2seq \
37
+ --disable_image_transform \
38
+ --image_size 256 \
39
+ --image_scale 1 \
40
+ --compute_stats \
41
+ --plot_gt \
42
+ --plot_polys \
43
+ --plot_density
44
+
45
+
46
+ SPLIT=test
47
+ python plot_floor.py --dataset_name=cubicasa \
48
+ --dataset_root=data/coco_cubicasa5k_nowalls_v4-1_refined \
49
+ --eval_set=${SPLIT} \
50
+ --output_dir=data_plots/output_gt_cc5k/${SPLIT} \
51
+ --semantic_classes=12 \
52
+ --input_channels 3 \
53
+ --disable_image_transform \
54
+ --poly2seq \
55
+ --image_size 256 \
56
+ --image_scale 1 \
57
+ --compute_stats \
58
+ --plot_gt \
59
+ --plot_polys \
60
+ --plot_density
data_preprocess/tools/run_cc5k.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # create COCO-style dataset for CubiCasa5k
2
+ python -m data_preprocess.cubicasa5k.create_coco_cc5k --data_root=data/cubicasa5k/ \
3
+ --output=data/coco_cubicasa5k_nowalls_v4/ \
4
+ --disable_wd2line
5
+
6
+ # Split example has more than 1 floorplan into separate samples
7
+ python -m data_preprocess.cubicasa5k.create_coco_cc5k.floorplan_extraction \
8
+ --data_root data/coco_cubicasa5k_nowalls_v4/ \
9
+ --output data/coco_cubicasa5k_nowalls_v4-1_refined/
10
+
11
+ # Merge individual JSONs into single JSON file per split (train/val/test)
12
+ # This must be done after floorplan_extraction.py
13
+ python -m data_preprocess.cubicasa5k.combine_json \
14
+ --input data/coco_cubicasa5k_nowalls_v4-1_refined/ \
15
+ --output data/coco_cubicasa5k_nowalls_v4-1_refined/annotations/ \
data_preprocess/tools/run_r2g.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # preprocess raw Raster2Graph high-resolution dataset
2
+ python -m data_preprocess.raster2graph.image_process --data_root=data/R2G_hr_dataset/
3
+
4
+ # convert to COCO-style dataset
5
+ python -m data_preprocess.raster2graph.convert_to_coco --dataset_path data/R2G_hr_dataset/ --output_dir data/R2G_hr_dataset_processed/
6
+
7
+ # combine JSON files into single JSON file per split
8
+ python -m data_preprocess.raster2graph.combine_json \
9
+ --input data/R2G_hr_dataset_processed/ \
10
+ --output data/R2G_hr_dataset_processed_v1/ \
11
+
12
+ rm -rf data/R2G_hr_dataset_processed/
data_preprocess/tools/run_s3d.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Assume the Structured3D density dataset are downloaded
2
+ DATA=data/coco_s3d
3
+
4
+ for split in train val test; do
5
+ python plot_floor.py --dataset_name=stru3d \
6
+ --dataset_root=${DATA} \
7
+ --eval_set=${split} \
8
+ --output_dir=data/coco_s3d_bw/${split}/ \
9
+ --semantic_classes=19 \
10
+ --input_channels 3 \
11
+ --disable_image_transform \
12
+ --poly2seq \
13
+ --image_size 256 \
14
+ --image_scale 1 \
15
+ --plot_gt \
16
+ --is_bw \
17
+ --plot_engine matplotlib
18
+
19
+ done
20
+
21
+ # Reuse the annotations
22
+ cp -r data/coco_s3d/annotations data/coco_s3d_bw/
data_preprocess/tools/run_waffle.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ python -m data_preprocess.waffle.create_coco_waffle_benchmark \
2
+ --data_root data/waffle/benchmark/ \
3
+ --output data/waffle_benchmark_processed/
data_preprocess/waffle/create_coco_waffle_benchmark.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ from glob import glob
6
+ from pathlib import Path
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+ from shapely.geometry import Polygon
12
+
13
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
14
+ from common_utils import resort_corners
15
+
16
+
17
+ def draw_polygon_on_image(image, polygons, class_to_color):
18
+ """
19
+ Draws polygons on the image based on the COLOR_TO_CLASS mapping.
20
+
21
+ Args:
22
+ image (numpy.ndarray): The image on which to draw.
23
+ polygons (list of list of tuple): List of polygons, where each polygon is a list of (x, y) points.
24
+
25
+ Returns:
26
+ numpy.ndarray: The image with polygons drawn.
27
+ """
28
+ # Draw each polygon on the image
29
+ for polygon, polygon_class in polygons:
30
+ # Convert polygon points to numpy array
31
+ pts = np.array(polygon, dtype=np.int32).reshape(-1, 2)
32
+ color = class_to_color[polygon_class]
33
+ bgr = (color[2], color[1], color[0]) # Convert RGB to BGR for OpenCV
34
+ # Draw filled polygon
35
+ cv2.fillPoly(image, [pts], bgr)
36
+
37
+ return image
38
+
39
+
40
+ def fill_mask(segmentation_mask):
41
+ filled_mask = np.zeros_like(segmentation_mask, dtype=np.uint8)
42
+
43
+ # Iterate over each class index in the segmentation mask
44
+ for class_index in np.unique(segmentation_mask):
45
+ if class_index == 0: # Skip the background
46
+ continue
47
+
48
+ # Create a binary mask for the current class
49
+ binary_mask = (segmentation_mask == class_index).astype(np.uint8)
50
+
51
+ # Find contours for the current class
52
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
53
+
54
+ # Fill each contour with white color in the single-channel mask
55
+ cv2.drawContours(filled_mask, contours, -1, 255, thickness=cv2.FILLED)
56
+
57
+ return filled_mask
58
+
59
+
60
+ def to_bw_image(input_image):
61
+ # Convert the input image to grayscale
62
+ gray_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY)
63
+
64
+ # Apply a binary threshold to convert the grayscale image to black and white
65
+ _, bw_image = cv2.threshold(gray_image, 127, 255, cv2.THRESH_BINARY)
66
+ return bw_image
67
+
68
+
69
+ def create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2):
70
+ bb_x = np.unique(bb_x)
71
+ bb_y = np.unique(bb_y)
72
+ bb_x_min = np.maximum(np.min(bb_x) - bound_pad, 0)
73
+ bb_y_min = np.maximum(np.min(bb_y) - bound_pad, 0)
74
+
75
+ bb_x_max = np.minimum(np.max(bb_x) + bound_pad, image_width - 1)
76
+ bb_y_max = np.minimum(np.max(bb_y) + bound_pad, image_height - 1)
77
+
78
+ bb_width = bb_x_max - bb_x_min
79
+ bb_height = bb_y_max - bb_y_min
80
+
81
+ coco_bb = [bb_x_min, bb_y_min, bb_width, bb_height]
82
+ return coco_bb
83
+
84
+
85
+ def prepare_dict(categories_dict):
86
+ save_dict = {"images": [], "annotations": [], "categories": []}
87
+ for key, value in categories_dict.items():
88
+ type_dict = {"supercategory": "room", "id": value, "name": key}
89
+ save_dict["categories"].append(type_dict)
90
+ return save_dict
91
+
92
+
93
+ def convert_numpy_to_python(obj):
94
+ if isinstance(obj, np.integer):
95
+ return int(obj)
96
+ elif isinstance(obj, np.floating):
97
+ return float(obj)
98
+ elif isinstance(obj, np.ndarray):
99
+ return obj.tolist()
100
+ else:
101
+ return obj
102
+
103
+
104
+ def config():
105
+ a = argparse.ArgumentParser(description="Generate coco format data for WAFFLE BENCHMARK SET")
106
+ a.add_argument("--data_root", default="data/waffle/benchmark/", type=str, help="path to WAFFLE BENCHMARK folder")
107
+ a.add_argument("--output", default="data/waffle_benchmark_processed/", type=str, help="path to output folder")
108
+
109
+ args = a.parse_args()
110
+ return args
111
+
112
+
113
+ if __name__ == "__main__":
114
+ LABEL_NOTATIONS = {
115
+ "Background": (0, 0, 0), # Black
116
+ "Interior": (255, 255, 255), # White
117
+ "Walls": (255, 0, 0), # Red
118
+ "Doors": (0, 0, 255), # Blue
119
+ "Windows": (0, 255, 255), # Cyan
120
+ }
121
+
122
+ CLASS2INDEX = {
123
+ "Background": 0, # Black
124
+ "Interior": 1, # White
125
+ # "Walls": 2, # Red
126
+ "Doors": 3, # Blue
127
+ "Windows": 4, # Cyan
128
+ }
129
+
130
+ # Create a mapping from RGB values to class indices
131
+ COLOR_TO_CLASS = {
132
+ (0, 0, 0): 0, # Background
133
+ (255, 255, 255): 1, # Interior
134
+ (255, 0, 0): 2, # Walls
135
+ (0, 0, 255): 3, # Doors
136
+ (0, 255, 255): 4, # Windows
137
+ }
138
+
139
+ NEW_CLASS_MAPPING = {
140
+ 1: 0,
141
+ 3: 1,
142
+ 4: 2,
143
+ }
144
+
145
+ CLASS_TO_COLOR = {
146
+ 0: (255, 255, 255), # Interior
147
+ 1: (0, 0, 255), # Doors
148
+ 2: (0, 255, 255), # Windows
149
+ }
150
+
151
+ args = config()
152
+
153
+ root = args.data_root
154
+ image_dir = f"{root}/pngs"
155
+ label_dir = f"{root}/segmented_descrete_pngs"
156
+ input_paths = sorted(glob(f"{label_dir}/*.png"))
157
+
158
+ output_dir = args.output
159
+ output_aux_dir = f"{output_dir}/aux"
160
+ output_image_dir = f"{output_dir}/test/"
161
+ output_annot_dir = f"{output_dir}/annotations/"
162
+ fn_mapping_log = f"{output_annot_dir}/test_image_id_mapping.json"
163
+ os.makedirs(output_dir, exist_ok=True)
164
+ os.makedirs(output_aux_dir, exist_ok=True)
165
+ os.makedirs(output_image_dir, exist_ok=True)
166
+ os.makedirs(output_annot_dir, exist_ok=True)
167
+
168
+ instance_count = 0
169
+
170
+ save_dict = prepare_dict(CLASS2INDEX)
171
+ output_mappings = []
172
+
173
+ for i, path in enumerate(input_paths):
174
+ # if i > 5:
175
+ # exit(0)
176
+ mask = Image.open(path).convert("RGB")
177
+ fn = os.path.basename(path).replace("_seg_colors.png", "")
178
+ new_fn = str(i).zfill(5)
179
+
180
+ mask = np.array(mask)
181
+ image = Image.open(os.path.join(image_dir, f"{fn}.png")).convert("RGB")
182
+ image_width, image_height = image.size
183
+
184
+ # Initialize an empty segmentation mask with the same height and width as the input mask
185
+ segmentation_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8)
186
+
187
+ img_id = i
188
+ img_dict = {}
189
+ img_dict["file_name"] = str(img_id).zfill(5) + ".png"
190
+ img_dict["id"] = img_id
191
+ img_dict["width"] = image_width
192
+ img_dict["height"] = image_height
193
+
194
+ output_polygons = []
195
+ coco_annotation_dict_list = []
196
+ # Iterate over each pixel in the mask and assign the corresponding class index
197
+ for color, class_index in COLOR_TO_CLASS.items():
198
+ # Create a boolean mask for the current color
199
+ color_mask = (mask == color).all(axis=-1)
200
+ color_mask_uint8 = color_mask.astype(np.uint8)
201
+
202
+ # Assign the class index to the segmentation mask
203
+ segmentation_mask[color_mask] = class_index
204
+
205
+ if class_index not in NEW_CLASS_MAPPING:
206
+ continue
207
+ class_index = NEW_CLASS_MAPPING[class_index]
208
+
209
+ # Find contours for the current color mask
210
+ contours, _ = cv2.findContours(color_mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
211
+ new_contours = []
212
+ for cnt in contours:
213
+ peri = cv2.arcLength(cnt, True)
214
+ approx = cv2.approxPolyDP(cnt, 0.001 * peri, True)
215
+ new_contours.append(approx)
216
+
217
+ # Convert contours to polygon coordinates
218
+ polygons = [contour.reshape(-1, 2) for contour in new_contours]
219
+
220
+ for polygon in polygons:
221
+ # Convert the polygon to a Shapely Polygon object
222
+ if polygon.shape[0] < 3:
223
+ continue
224
+
225
+ shapely_polygon = Polygon(polygon)
226
+ area = shapely_polygon.area
227
+ rectangle_shapely = shapely_polygon.envelope
228
+ bb_x, bb_y = rectangle_shapely.exterior.xy
229
+ coco_bb = create_coco_bounding_box(bb_x, bb_y, image_width, image_height, bound_pad=2)
230
+
231
+ if class_index in [3, 4] and area < 1:
232
+ continue
233
+ if class_index not in [3, 4] and area < 100:
234
+ continue
235
+
236
+ coco_seg_poly = []
237
+ poly_sorted = resort_corners(polygon)
238
+ # image = draw_polygon_on_image(image, poly_shapely, "test_poly.jpg")
239
+
240
+ for p in poly_sorted:
241
+ coco_seg_poly += list(p)
242
+
243
+ # Create a dictionary for the COCO annotation
244
+ coco_annotation_dict = {
245
+ "segmentation": [coco_seg_poly],
246
+ "area": area,
247
+ "iscrow": 0,
248
+ "image_id": i,
249
+ "bbox": coco_bb,
250
+ "category_id": class_index,
251
+ "id": instance_count,
252
+ }
253
+ coco_annotation_dict_list.append(coco_annotation_dict)
254
+ instance_count += 1
255
+ output_polygons.append([coco_seg_poly, class_index])
256
+
257
+ save_dict["images"].append(img_dict)
258
+ save_dict["annotations"] += coco_annotation_dict_list
259
+
260
+ # Print the unique class indices in the segmentation mask to verify
261
+ print(path)
262
+ print(np.unique(segmentation_mask))
263
+
264
+ filled_mask = fill_mask(segmentation_mask)
265
+
266
+ clean_image = np.array(image)
267
+ filled_mask_resized = cv2.resize(
268
+ filled_mask, (clean_image.shape[1], clean_image.shape[0]), interpolation=cv2.INTER_NEAREST
269
+ )
270
+ cv2.imwrite(f"{output_aux_dir}/{fn}_fg_mask.png", filled_mask_resized)
271
+
272
+ clean_image = clean_image * np.array(filled_mask_resized[:, :, np.newaxis] / 255.0).astype(bool)
273
+ clean_image[filled_mask_resized == 0] = 255
274
+ clean_image = cv2.cvtColor(clean_image, cv2.COLOR_RGB2BGR)
275
+ # clean_image = to_bw_image(clean_image)
276
+ cv2.imwrite(f"{output_image_dir}/{new_fn}.png", clean_image)
277
+
278
+ image_with_polygons = draw_polygon_on_image(np.zeros_like(clean_image), output_polygons, CLASS_TO_COLOR)
279
+ cv2.imwrite(f"{output_aux_dir}/{fn}_polylines.png", image_with_polygons)
280
+
281
+ output_mappings.append(f"{fn} {new_fn}")
282
+
283
+ with open(fn_mapping_log, "w") as f:
284
+ for mapping in output_mappings:
285
+ f.write(f"{mapping}\n")
286
+
287
+ # Serialize save_dict to JSON
288
+ json_path = f"{output_annot_dir}/test.json"
289
+ with open(json_path, "w") as f:
290
+ json.dump(save_dict, f, default=convert_numpy_to_python)
datasets/__init__.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .poly_data import build as build_poly
2
+
3
+
4
+ def build_dataset(image_set, args):
5
+ if args.dataset_name in ["stru3d", "cubicasa", "waffle", "r2g"]:
6
+ print(f"Build {args.dataset_name} {image_set} dataset")
7
+ return build_poly(image_set, args)
8
+ raise ValueError(f"dataset {args.dataset_name} not supported")
9
+
10
+
11
+ def get_dataset_class_labels(dataset_name):
12
+ semantics_label = None
13
+
14
+ if dataset_name == "stru3d":
15
+ semantics_label = {
16
+ 0: "Living Room",
17
+ 1: "Kitchen",
18
+ 2: "Bedroom",
19
+ 3: "Bathroom",
20
+ 4: "Balcony",
21
+ 5: "Corridor",
22
+ 6: "Dining room",
23
+ 7: "Study",
24
+ 8: "Studio",
25
+ 9: "Store room",
26
+ 10: "Garden",
27
+ 11: "Laundry room",
28
+ 12: "Office",
29
+ 13: "Basement",
30
+ 14: "Garage",
31
+ 15: "Misc.",
32
+ 16: "Door",
33
+ 17: "Window",
34
+ }
35
+ elif dataset_name == "cubicasa":
36
+ semantics_label = {
37
+ "Outdoor": 0,
38
+ "Kitchen": 1,
39
+ "Living Room": 2,
40
+ "Bed Room": 3,
41
+ "Bath": 4,
42
+ "Entry": 5,
43
+ "Storage": 6,
44
+ "Garage": 7,
45
+ "Undefined": 8,
46
+ "Window": 9,
47
+ "Door": 10,
48
+ }
49
+ elif dataset_name == "r2g":
50
+ semantics_label = {
51
+ "unknown": 0,
52
+ "living_room": 1,
53
+ "kitchen": 2,
54
+ "bedroom": 3,
55
+ "bathroom": 4,
56
+ "restroom": 5,
57
+ "balcony": 6,
58
+ "closet": 7,
59
+ "corridor": 8,
60
+ "washing_room": 9,
61
+ "PS": 10,
62
+ "outside": 11,
63
+ }
64
+
65
+ id2class = {v: k for k, v in semantics_label.items()} if semantics_label else None
66
+
67
+ return semantics_label, id2class
datasets/data_utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+
4
+
5
+ def compute_centroid(polygon):
6
+ """Compute centroid of a polygon given as list of (x, y)."""
7
+ polygon = np.array(polygon)
8
+ x = np.mean(polygon[:, 0])
9
+ y = np.mean(polygon[:, 1])
10
+ return (x, y)
11
+
12
+
13
+ def get_top_left(polygon):
14
+ return min(polygon, key=lambda p: (p[1], p[0])) # y ascending, x ascending
15
+
16
+
17
+ def sort_polygons(polygons, tolerance=20, reverse=False):
18
+ # Step 1: Get top-left corner and original index
19
+ indexed = [(i, get_top_left(p), p) for i, p in enumerate(polygons)]
20
+
21
+ # Step 2: Sort by Y (top to bottom)
22
+ indexed.sort(key=lambda x: x[1][1])
23
+
24
+ # Step 3: Group into rows
25
+ rows = []
26
+ for idx, corner, poly in indexed:
27
+ y = corner[1]
28
+ added = False
29
+ for row in rows:
30
+ if abs(row[0][1][1] - y) <= tolerance:
31
+ row.append((idx, corner, poly))
32
+ added = True
33
+ break
34
+ if not added:
35
+ rows.append([(idx, corner, poly)])
36
+
37
+ # Step 4: Sort each row left-to-right
38
+ for row in rows:
39
+ row.sort(key=lambda x: x[1][0]) # sort by x
40
+
41
+ # Step 5: Flatten and return indices
42
+ sorted_indices = [idx for row in rows for idx, _, _ in row]
43
+ if reverse:
44
+ sorted_indices = sorted_indices[::-1]
45
+ sorted_polygons = [polygons[idx] for idx in sorted_indices]
46
+
47
+ return sorted_polygons, sorted_indices
48
+
49
+
50
+ def plot_polygons(polygons, save_path):
51
+ plt.figure(figsize=(6, 6))
52
+ for i, poly in enumerate(polygons):
53
+ poly = np.array(poly)
54
+ plt.fill(poly[:, 0], poly[:, 1], alpha=0.5, label=f"Polygon {i + 1}")
55
+ centroid = compute_centroid(poly)
56
+ plt.text(centroid[0], centroid[1], f"C{i + 1}", fontsize=10, ha="center")
57
+ # plt.title(title)
58
+ # plt.legend()
59
+ plt.gca().set_aspect("equal", adjustable="box")
60
+ plt.savefig(save_path)
datasets/discrete_tokenizer.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ class DiscreteTokenizer(object):
6
+ def __init__(self, num_bins, seq_len, add_cls=False):
7
+ self.num_bins = num_bins
8
+ vocab_size = num_bins * num_bins
9
+ self.seq_len = seq_len
10
+ self.add_cls = add_cls
11
+
12
+ self.bos = vocab_size + 0
13
+ self.eos = vocab_size + 1
14
+ self.sep = vocab_size + 2
15
+ self.pad = vocab_size + 3
16
+ if add_cls:
17
+ self.cls = vocab_size + 4
18
+ self.vocab_size = vocab_size + 5
19
+ else:
20
+ self.vocab_size = vocab_size + 4
21
+
22
+ def __len__(self):
23
+ return self.vocab_size
24
+
25
+ def _padding(self, seq, pad_value, dtype):
26
+ if self.seq_len > len(seq):
27
+ seq.extend([pad_value] * (self.seq_len - len(seq)))
28
+ return torch.tensor(np.array(seq), dtype=dtype)
29
+
30
+ def __call__(self, seq, add_bos, add_eos, dtype, return_indices=False):
31
+ out = []
32
+ if add_bos:
33
+ out = [self.bos]
34
+ num_extra = 1 if not self.add_cls else 2 # cls and sep
35
+ indices = []
36
+ for i, sub in enumerate(seq):
37
+ cur_len = len(out)
38
+ # Append sub only if it doesn't exceed seq_len
39
+ if cur_len + len(sub) + num_extra <= self.seq_len:
40
+ out.extend(sub)
41
+ indices.append(i)
42
+ else:
43
+ continue
44
+ # Append cls and sep tokens only if it doesn't exceed seq_len
45
+ if self.add_cls:
46
+ out.append(self.cls) # cls token
47
+ out.append(self.sep)
48
+ # Remove last separator token if present
49
+ if out and out[-1] == self.sep:
50
+ out.pop(-1) # remove last separator token
51
+
52
+ if self.seq_len > len(out):
53
+ out.extend([self.pad] * (self.seq_len - len(out)))
54
+
55
+ if add_eos:
56
+ out[-1] = self.eos
57
+
58
+ if return_indices:
59
+ return torch.tensor(out, dtype=dtype), indices
60
+ return torch.tensor(out, dtype=dtype)
datasets/poly_data.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from enum import Enum
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ from PIL import Image
11
+ from pycocotools.coco import COCO
12
+ from torch.utils.data import Dataset
13
+
14
+ from datasets.data_utils import sort_polygons
15
+ from datasets.discrete_tokenizer import DiscreteTokenizer
16
+ from datasets.transforms import ResizeAndPad
17
+ from detectron2.data import transforms as T
18
+ from detectron2.data.detection_utils import annotations_to_instances, transform_instance_annotations
19
+ from detectron2.structures import BoxMode
20
+ from util.poly_ops import resort_corners
21
+
22
+
23
+ class TokenType(Enum):
24
+ """0 for <coord>, 1 for <sep>, 2 for <eos>, 3 for <cls>"""
25
+
26
+ coord = 0
27
+ sep = 1
28
+ eos = 2
29
+ cls = 3
30
+
31
+
32
+ WD_INDEX = {
33
+ "stru3d": [16, 17],
34
+ "cubicasa": [9, 10],
35
+ "waffle": [],
36
+ "r2g": [],
37
+ }
38
+
39
+
40
+ class MultiPoly(Dataset):
41
+ def __init__(
42
+ self,
43
+ img_folder,
44
+ ann_file,
45
+ transforms,
46
+ semantic_classes,
47
+ dataset_name="",
48
+ image_norm=False,
49
+ poly2seq=False,
50
+ converter_version="v1",
51
+ random_drop_rate=0.0,
52
+ **kwargs,
53
+ ):
54
+ super(MultiPoly, self).__init__()
55
+
56
+ self.root = img_folder
57
+ self._transforms = transforms
58
+ self.semantic_classes = semantic_classes
59
+ self.dataset_name = dataset_name
60
+
61
+ self.coco = COCO(ann_file)
62
+ self.ids = list(sorted(self.coco.imgs.keys()))
63
+
64
+ self.poly2seq = poly2seq
65
+ self.prepare = ConvertToCocoDictWithOrder_plus(
66
+ self.root,
67
+ self._transforms,
68
+ image_norm,
69
+ poly2seq,
70
+ semantic_classes=semantic_classes,
71
+ order_type=["l2r", "r2l"][converter_version == "v3_flipped"],
72
+ random_drop_rate=random_drop_rate,
73
+ **kwargs,
74
+ )
75
+
76
+ def get_image(self, path):
77
+ return Image.open(os.path.join(self.root, path))
78
+
79
+ def get_vocab_size(self):
80
+ if self.poly2seq:
81
+ return len(self.prepare.tokenizer)
82
+ return None
83
+
84
+ def get_tokenizer(self):
85
+ if self.poly2seq:
86
+ return self.prepare.tokenizer
87
+ return None
88
+
89
+ def __len__(self):
90
+ return len(self.ids)
91
+
92
+ def __getitem__(self, index):
93
+ """
94
+ Args:
95
+ index (int): Index
96
+ Returns:
97
+ dict: COCO format dict
98
+ """
99
+ coco = self.coco
100
+ img_id = self.ids[index]
101
+
102
+ ann_ids = coco.getAnnIds(imgIds=img_id)
103
+ target = coco.loadAnns(ann_ids)
104
+
105
+ ### Note: here is a hack which assumes door/window have category_id 16, 17 in structured3D
106
+ if self.semantic_classes == -1:
107
+ if self.dataset_name == "stru3d":
108
+ target = [t for t in target if t["category_id"] not in WD_INDEX["stru3d"]]
109
+ # elif self.dataset_name == 'rplan':
110
+ # target = [t for t in target if t['category_id'] not in [9, 11]]
111
+ elif self.dataset_name == "cubicasa":
112
+ target = [t for t in target if t["category_id"] not in WD_INDEX["cubicasa"]]
113
+
114
+ path = coco.loadImgs(img_id)[0]["file_name"]
115
+
116
+ record = self.prepare(img_id, path, target)
117
+
118
+ return record
119
+
120
+
121
+ class MultiPolyWD(MultiPoly):
122
+ def __getitem__(self, index):
123
+ """
124
+ Args:
125
+ index (int): Index
126
+ Returns:
127
+ dict: COCO format dict
128
+ """
129
+ coco = self.coco
130
+ img_id = self.ids[index]
131
+
132
+ ann_ids = coco.getAnnIds(imgIds=img_id)
133
+ target = coco.loadAnns(ann_ids)
134
+
135
+ ### Note: here is a hack which assumes door/window have category_id 16, 17 in structured3D
136
+ # if self.semantic_classes == -1:
137
+ # if self.dataset_name == 'stru3d':
138
+ # target = [t for t in target if t['category_id'] not in [16, 17]]
139
+ # elif self.dataset_name == 'rplan':
140
+ # target = [t for t in target if t['category_id'] not in [9, 11]]
141
+ # elif self.dataset_name == 'cubicasa':
142
+ # target = [t for t in target if t['category_id'] not in [9, 10]]
143
+
144
+ if self.dataset_name == "stru3d":
145
+ target = [t for t in target if t["category_id"] in [16, 17]]
146
+ elif self.dataset_name == "rplan":
147
+ target = [t for t in target if t["category_id"] in [9, 11]]
148
+ elif self.dataset_name == "cubicasa":
149
+ target = [t for t in target if t["category_id"] in [9, 10]]
150
+
151
+ path = coco.loadImgs(img_id)[0]["file_name"]
152
+ record = self.prepare(img_id, path, target)
153
+
154
+ return record
155
+
156
+
157
+ class ConvertToCocoDict(object):
158
+ def __init__(
159
+ self,
160
+ root,
161
+ augmentations,
162
+ image_norm,
163
+ poly2seq=False,
164
+ semantic_classes=-1,
165
+ add_cls_token=False,
166
+ per_token_class=False,
167
+ mask_format="polygon",
168
+ **kwargs,
169
+ ):
170
+ self.root = root
171
+ self.augmentations = augmentations
172
+ if image_norm:
173
+ self.image_normalize = torchvision.transforms.Normalize(
174
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
175
+ )
176
+ else:
177
+ self.image_normalize = None
178
+
179
+ self.semantic_classes = semantic_classes
180
+ self.poly2seq = poly2seq
181
+ if poly2seq:
182
+ self.tokenizer = DiscreteTokenizer(add_cls=add_cls_token, **kwargs)
183
+ self.add_cls_token = add_cls_token
184
+ self.per_token_class = per_token_class
185
+ self.mask_format = mask_format
186
+
187
+ def _expand_image_dims(self, x):
188
+ if len(x.shape) == 2:
189
+ exp_img = np.expand_dims(x, 0)
190
+ else:
191
+ exp_img = x.transpose((2, 0, 1)) # (h,w,c) -> (c,h,w)
192
+ return exp_img
193
+
194
+ def __call__(self, img_id, path, target):
195
+
196
+ file_name = os.path.join(self.root, path)
197
+
198
+ img = np.array(Image.open(file_name))
199
+
200
+ #### NEW
201
+ if len(img.shape) >= 3:
202
+ if img.shape[-1] > 3: # drop alpha channel
203
+ img = img[:, :, :3]
204
+ w, h = img.shape[:-1]
205
+ else:
206
+ # print(img.shape, file_name)
207
+ w, h = img.shape
208
+ #### NEW
209
+
210
+ record = {}
211
+ record["file_name"] = file_name
212
+ record["height"] = h
213
+ record["width"] = w
214
+ record["image_id"] = img_id
215
+
216
+ for obj in target:
217
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
218
+
219
+ record["annotations"] = target
220
+
221
+ if self.augmentations is None:
222
+ record["image"] = (1 / 255) * torch.as_tensor(np.ascontiguousarray(self._expand_image_dims(img)))
223
+ record["instances"] = annotations_to_instances(target, (h, w), mask_format=self.mask_format)
224
+ else:
225
+ aug_input = T.AugInput(img)
226
+ transforms = self.augmentations(aug_input)
227
+ image = aug_input.image
228
+ record["image"] = (1 / 255) * torch.as_tensor(np.array(self._expand_image_dims(image)))
229
+ h, w = image.shape[:2] # update size
230
+
231
+ annos = [
232
+ transform_instance_annotations(obj, transforms, image.shape[:2])
233
+ for obj in record.pop("annotations")
234
+ if obj.get("iscrowd", 0) == 0
235
+ ]
236
+ # resort corners after augmentation: so that all corners start from upper-left counterclockwise
237
+ for anno in annos:
238
+ anno["segmentation"][0] = resort_corners(anno["segmentation"][0])
239
+
240
+ record["instances"] = annotations_to_instances(annos, (h, w), mask_format=self.mask_format)
241
+
242
+ #### NEW ####
243
+ if self.image_normalize is not None:
244
+ record["image"] = self.image_normalize(record["image"])
245
+
246
+ # convert polygons to sequences
247
+ if self.poly2seq:
248
+ # only happend for wdonly
249
+ if not hasattr(record["instances"], "gt_masks"):
250
+ polygons = [np.array([[0.0, 0.0]])]
251
+ polygons_label = [self.semantic_classes - 1] # dummy class
252
+ else:
253
+ polygons = [
254
+ np.clip(np.array(inst).reshape(-1, 2) / (w - 1), 0, 1)
255
+ for inst in record["instances"].gt_masks.polygons
256
+ ]
257
+ polygons_label = [inst.item() for inst in record["instances"].gt_classes]
258
+ record.update(
259
+ self._get_bilinear_interpolation_coeffs(
260
+ polygons, polygons_label, self.add_cls_token, self.per_token_class
261
+ )
262
+ )
263
+
264
+ return record
265
+
266
+ def _get_bilinear_interpolation_coeffs(self, polygons, polygons_label, add_cls_token=False, per_token_class=False):
267
+ num_bins = self.tokenizer.num_bins
268
+ quant_poly = [poly * (num_bins - 1) for poly in polygons]
269
+ index11 = [[math.floor(p[0]) * num_bins + math.floor(p[1]) for p in poly] for poly in quant_poly]
270
+ index21 = [[math.ceil(p[0]) * num_bins + math.floor(p[1]) for p in poly] for poly in quant_poly]
271
+ index12 = [[math.floor(p[0]) * num_bins + math.ceil(p[1]) for p in poly] for poly in quant_poly]
272
+ index22 = [[math.ceil(p[0]) * num_bins + math.ceil(p[1]) for p in poly] for poly in quant_poly]
273
+
274
+ seq11 = self.tokenizer(index11, add_bos=True, add_eos=False, dtype=torch.long)
275
+ seq21 = self.tokenizer(index21, add_bos=True, add_eos=False, dtype=torch.long)
276
+ seq12 = self.tokenizer(index12, add_bos=True, add_eos=False, dtype=torch.long)
277
+ seq22 = self.tokenizer(index22, add_bos=True, add_eos=False, dtype=torch.long)
278
+
279
+ # in real values insteads
280
+ target_seq = []
281
+ token_labels = [] # 0 for <coord>, 1 for <sep>, 2 for <eos>, 3 for <cls>
282
+ num_extra = 1 if not add_cls_token else 2 # cls and sep
283
+ count_polys = 0
284
+ for poly in polygons:
285
+ cur_len = len(token_labels)
286
+ if cur_len + len(poly) + num_extra > self.tokenizer.seq_len:
287
+ break # INFO: change from break to continue
288
+ token_labels.extend([TokenType.coord.value] * len(poly))
289
+ if add_cls_token:
290
+ token_labels.append(TokenType.cls.value) # cls token
291
+ token_labels.append(TokenType.sep.value) # separator token
292
+ target_seq.extend(poly)
293
+ if add_cls_token:
294
+ target_seq.append([0, 0]) # padding for cls token
295
+ target_seq.append([0, 0]) # padding for sep/end token
296
+ count_polys += 1
297
+ # remove last separator token
298
+ if len(token_labels) > 0:
299
+ token_labels[-1] = TokenType.eos.value
300
+ mask = torch.ones(self.tokenizer.seq_len, dtype=torch.bool)
301
+ if len(token_labels) < self.tokenizer.seq_len:
302
+ mask[len(token_labels) :] = 0
303
+ target_seq = self.tokenizer._padding(target_seq, [0, 0], dtype=torch.float32)
304
+ token_labels = self.tokenizer._padding(token_labels, -1, dtype=torch.long)
305
+
306
+ delta_x1 = [0] # [0] for bos token
307
+ for polygon in quant_poly[:count_polys]:
308
+ delta = [poly_point[0] - math.floor(poly_point[0]) for poly_point in polygon]
309
+ delta_x1.extend(delta)
310
+ if add_cls_token:
311
+ delta_x1.extend([0]) # for cls token
312
+ delta_x1.extend([0]) # for separator token
313
+ delta_x1 = delta_x1[:-1] # there is no separator token in the end
314
+ delta_x1 = self.tokenizer._padding(delta_x1, 0, dtype=torch.float32)
315
+ delta_x2 = 1 - delta_x1
316
+
317
+ delta_y1 = [0] # [0] for bos token
318
+ for polygon in quant_poly[:count_polys]:
319
+ delta = [poly_point[1] - math.floor(poly_point[1]) for poly_point in polygon]
320
+ delta_y1.extend(delta)
321
+ if add_cls_token:
322
+ delta_y1.extend([0]) # for cls token
323
+ delta_y1.extend([0]) # for separator token
324
+ delta_y1 = delta_y1[:-1] # there is no separator token in the end
325
+ delta_y1 = self.tokenizer._padding(delta_y1, 0, dtype=torch.float32)
326
+ delta_y2 = 1 - delta_y1
327
+
328
+ if not per_token_class:
329
+ target_polygon_labels = polygons_label[:count_polys]
330
+ else:
331
+ target_polygon_labels = []
332
+ for poly, poly_label in zip(quant_poly[:count_polys], polygons_label[:count_polys]):
333
+ target_polygon_labels.extend([poly_label] * len(poly))
334
+ target_polygon_labels.append(self.semantic_classes - 1) # undefined class for <sep> and <eos> token
335
+
336
+ max_label_length = self.tokenizer.seq_len
337
+ if len(polygons_label) < max_label_length:
338
+ target_polygon_labels.extend([-1] * (max_label_length - len(target_polygon_labels)))
339
+
340
+ target_polygon_labels = torch.tensor(target_polygon_labels, dtype=torch.long)
341
+
342
+ return {
343
+ "delta_x1": delta_x1,
344
+ "delta_x2": delta_x2,
345
+ "delta_y1": delta_y1,
346
+ "delta_y2": delta_y2,
347
+ "seq11": seq11,
348
+ "seq21": seq21,
349
+ "seq12": seq12,
350
+ "seq22": seq22,
351
+ "target_seq": target_seq,
352
+ "token_labels": token_labels,
353
+ "mask": mask,
354
+ "target_polygon_labels": target_polygon_labels,
355
+ }
356
+
357
+
358
+ class ConvertToCocoDictWithOrder_plus(ConvertToCocoDict):
359
+ def __init__(
360
+ self,
361
+ root,
362
+ augmentations,
363
+ image_norm,
364
+ poly2seq=False,
365
+ semantic_classes=-1,
366
+ add_cls_token=False,
367
+ per_token_class=False,
368
+ mask_format="polygon",
369
+ dataset_name="stru3d",
370
+ order_type="l2r",
371
+ random_drop_rate=0.0,
372
+ **kwargs,
373
+ ):
374
+ super().__init__(
375
+ root,
376
+ augmentations,
377
+ image_norm,
378
+ poly2seq,
379
+ semantic_classes,
380
+ add_cls_token,
381
+ per_token_class,
382
+ mask_format,
383
+ **kwargs,
384
+ )
385
+ self.dataset_name = dataset_name
386
+ self.order_type = order_type # l2r, r2l
387
+ self.random_drop_rate = random_drop_rate
388
+ self.tokenizer = DiscreteTokenizer(add_cls=add_cls_token, **kwargs)
389
+
390
+ def _get_bilinear_interpolation_coeffs(self, polygons, polygons_label, add_cls_token=False, per_token_class=False):
391
+ num_bins = self.tokenizer.num_bins
392
+ room_indices = [
393
+ poly_idx
394
+ for poly_idx, poly_label in enumerate(polygons_label)
395
+ if poly_label not in WD_INDEX[self.dataset_name]
396
+ ]
397
+ wd_indices = [
398
+ poly_idx for poly_idx, poly_label in enumerate(polygons_label) if poly_label in WD_INDEX[self.dataset_name]
399
+ ]
400
+
401
+ _, room_sorted_indices = sort_polygons(
402
+ [polygons[poly_idx] for poly_idx in room_indices], reverse=(self.order_type == "r2l")
403
+ )
404
+ _, wd_sorted_indices = sort_polygons(
405
+ [polygons[poly_idx] for poly_idx in wd_indices], reverse=(self.order_type == "r2l")
406
+ )
407
+ room_indices = [room_indices[_idx] for _idx in room_sorted_indices]
408
+ wd_indices = [wd_indices[_idx] for _idx in wd_sorted_indices]
409
+
410
+ #### NEW ####
411
+ combined_indices = room_indices + wd_indices # room first
412
+ if self.random_drop_rate > 0 and len(combined_indices) > 2:
413
+ keep_indices = np.where(np.random.rand(len(combined_indices)) >= self.random_drop_rate)[0].tolist()
414
+ if len(keep_indices) > 0: # Only apply drop if we have something left
415
+ combined_indices = [combined_indices[i] for i in keep_indices]
416
+ #### NEW ####
417
+
418
+ polygons = [polygons[i] for i in combined_indices]
419
+ polygons_label = [polygons_label[i] for i in combined_indices]
420
+
421
+ quant_poly = [poly * (num_bins - 1) for poly in polygons]
422
+ index11 = [[math.floor(p[0]) * num_bins + math.floor(p[1]) for p in poly] for poly in quant_poly]
423
+ index21 = [[math.ceil(p[0]) * num_bins + math.floor(p[1]) for p in poly] for poly in quant_poly]
424
+ index12 = [[math.floor(p[0]) * num_bins + math.ceil(p[1]) for p in poly] for poly in quant_poly]
425
+ index22 = [[math.ceil(p[0]) * num_bins + math.ceil(p[1]) for p in poly] for poly in quant_poly]
426
+
427
+ seq11 = self.tokenizer(index11, add_bos=True, add_eos=False, dtype=torch.long)
428
+ seq21 = self.tokenizer(index21, add_bos=True, add_eos=False, dtype=torch.long)
429
+ seq12 = self.tokenizer(index12, add_bos=True, add_eos=False, dtype=torch.long)
430
+ seq22, poly_indices = self.tokenizer(
431
+ index22, add_bos=True, add_eos=False, dtype=torch.long, return_indices=True
432
+ )
433
+
434
+ # in real values insteads
435
+ target_seq = []
436
+ token_labels = [] # 0 for <coord>, 1 for <sep>, 2 for <eos>, 3 for <cls>
437
+
438
+ for i in poly_indices:
439
+ token_labels.extend([TokenType.coord.value] * len(polygons[i]))
440
+ if add_cls_token:
441
+ token_labels.append(TokenType.cls.value) # cls token
442
+ token_labels.append(TokenType.sep.value) # separator token
443
+ target_seq.extend(polygons[i])
444
+ if add_cls_token:
445
+ target_seq.append([0, 0]) # padding for cls token
446
+ target_seq.append([0, 0]) # padding for sep/end token
447
+ # remove last separator token
448
+ token_labels[-1] = TokenType.eos.value
449
+
450
+ mask = torch.ones(self.tokenizer.seq_len, dtype=torch.bool)
451
+ if len(token_labels) < self.tokenizer.seq_len:
452
+ mask[len(token_labels) :] = 0
453
+ target_seq = self.tokenizer._padding(target_seq, [0, 0], dtype=torch.float32)
454
+ token_labels = self.tokenizer._padding(token_labels, -1, dtype=torch.long)
455
+
456
+ delta_x1 = [0] # [0] for bos token
457
+ for i in poly_indices:
458
+ polygon = quant_poly[i]
459
+ delta = [poly_point[0] - math.floor(poly_point[0]) for poly_point in polygon]
460
+ delta_x1.extend(delta)
461
+ if add_cls_token:
462
+ delta_x1.extend([0]) # for cls token
463
+ delta_x1.extend([0]) # for separator token
464
+ delta_x1 = delta_x1[:-1] # there is no separator token in the end
465
+ delta_x1 = self.tokenizer._padding(delta_x1, 0, dtype=torch.float32)
466
+ delta_x2 = 1 - delta_x1
467
+
468
+ delta_y1 = [0] # [0] for bos token
469
+ for i in poly_indices:
470
+ polygon = quant_poly[i]
471
+ delta = [poly_point[1] - math.floor(poly_point[1]) for poly_point in polygon]
472
+ delta_y1.extend(delta)
473
+ if add_cls_token:
474
+ delta_y1.extend([0]) # for cls token
475
+ delta_y1.extend([0]) # for separator token
476
+ delta_y1 = delta_y1[:-1] # there is no separator token in the end
477
+ delta_y1 = self.tokenizer._padding(delta_y1, 0, dtype=torch.float32)
478
+ delta_y2 = 1 - delta_y1
479
+
480
+ if not per_token_class:
481
+ target_polygon_labels = [polygons_label[i] for i in poly_indices] # polygons_label[:count_polys]
482
+ input_polygon_labels = torch.tensor(target_polygon_labels.copy(), dtype=torch.long)
483
+ else:
484
+ target_polygon_labels = []
485
+ for i in poly_indices:
486
+ poly, poly_label = quant_poly[i], polygons_label[i]
487
+ target_polygon_labels.extend([poly_label] * len(poly))
488
+ target_polygon_labels.append(self.semantic_classes - 1) # undefined class for <sep> and <eos> token
489
+ input_polygon_labels = torch.tensor(
490
+ [self.semantic_classes - 1] + target_polygon_labels.copy()[:-1], dtype=torch.long
491
+ ) # right shift by one: <bos>, ..., <coord>
492
+
493
+ max_label_length = self.tokenizer.seq_len
494
+ if len(polygons_label) < max_label_length:
495
+ target_polygon_labels.extend([-1] * (max_label_length - len(target_polygon_labels)))
496
+
497
+ target_polygon_labels = torch.tensor(target_polygon_labels, dtype=torch.long)
498
+
499
+ return {
500
+ "delta_x1": delta_x1,
501
+ "delta_x2": delta_x2,
502
+ "delta_y1": delta_y1,
503
+ "delta_y2": delta_y2,
504
+ "seq11": seq11,
505
+ "seq21": seq21,
506
+ "seq12": seq12,
507
+ "seq22": seq22,
508
+ "target_seq": target_seq,
509
+ "token_labels": token_labels,
510
+ "mask": mask,
511
+ "target_polygon_labels": target_polygon_labels,
512
+ "input_polygon_labels": input_polygon_labels,
513
+ }
514
+
515
+
516
+ def make_poly_transforms(dataset_name, image_set, image_size=256, disable_image_transform=False):
517
+
518
+ trans_list = []
519
+ if dataset_name in ["cubicasa", "waffle"] or (dataset_name == "r2g" and image_size != 512):
520
+ trans_list = [ResizeAndPad((image_size, image_size), pad_value=255)]
521
+
522
+ if image_set == "train":
523
+ if not disable_image_transform:
524
+ trans_list.extend(
525
+ [
526
+ T.RandomFlip(prob=0.5, horizontal=True, vertical=False),
527
+ T.RandomFlip(prob=0.5, horizontal=False, vertical=True),
528
+ T.RandomRotation([0.0, 90.0, 180.0, 270.0], expand=False, center=None, sample_style="choice"),
529
+ ]
530
+ )
531
+ return T.AugmentationList(trans_list)
532
+
533
+ if image_set == "val" or image_set == "test":
534
+ return None if len(trans_list) == 0 else T.AugmentationList(trans_list)
535
+
536
+ raise ValueError(f"unknown {image_set}")
537
+
538
+
539
+ def build(image_set, args):
540
+ root = Path(args.dataset_root)
541
+ assert root.exists(), f"provided data path {root} does not exist"
542
+
543
+ PATHS = {
544
+ "train": (root / "train", root / "annotations" / "train.json"),
545
+ "val": (root / "val", root / "annotations" / "val.json"),
546
+ "test": (root / "test", root / "annotations" / "test.json"),
547
+ }
548
+
549
+ img_folder, ann_file = PATHS[image_set]
550
+ image_transform = make_poly_transforms(
551
+ args.dataset_name,
552
+ image_set,
553
+ image_size=args.image_size,
554
+ disable_image_transform=getattr(args, "disable_image_transform", False),
555
+ )
556
+
557
+ if args.wd_only:
558
+ dataset = MultiPolyWD(
559
+ img_folder,
560
+ ann_file,
561
+ transforms=image_transform,
562
+ semantic_classes=args.semantic_classes,
563
+ dataset_name=args.dataset_name,
564
+ image_norm=args.image_norm,
565
+ poly2seq=args.poly2seq,
566
+ num_bins=args.num_bins,
567
+ seq_len=args.seq_len,
568
+ add_cls_token=args.add_cls_token,
569
+ per_token_class=args.per_token_sem_loss,
570
+ mask_format=getattr(args, "mask_format", "polygon"),
571
+ )
572
+ else:
573
+ dataset = MultiPoly(
574
+ img_folder,
575
+ ann_file,
576
+ transforms=image_transform,
577
+ semantic_classes=args.semantic_classes,
578
+ dataset_name=args.dataset_name,
579
+ image_norm=args.image_norm,
580
+ poly2seq=args.poly2seq,
581
+ num_bins=args.num_bins,
582
+ seq_len=args.seq_len,
583
+ add_cls_token=args.add_cls_token,
584
+ per_token_class=args.per_token_sem_loss,
585
+ mask_format=getattr(args, "mask_format", "polygon"),
586
+ converter_version=getattr(args, "converter_version", "v1"),
587
+ random_drop_rate=getattr(args, "random_drop_rate", 0.0),
588
+ )
589
+
590
+ return dataset
datasets/room_dropout.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import List, Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from skimage.draw import polygon
7
+
8
+
9
+ class RoomDropoutStrategy:
10
+ """
11
+ Strategy for randomly dropping rooms from a density map using ground truth coordinates.
12
+
13
+ Density map: grayscale image where foreground (rooms) are white points and background is black
14
+ GT room coordinates: list of 2D points defining each room's boundary
15
+ """
16
+
17
+ def __init__(self, density_map: np.ndarray, room_coordinates: List[List[Tuple[int, int]]]):
18
+ """
19
+ Initialize the dropout strategy.
20
+
21
+ Args:
22
+ density_map: Grayscale image (H, W) where white pixels represent rooms
23
+ room_coordinates: List of rooms, each room is a list of (x, y) coordinate tuples
24
+ """
25
+ self.original_density_map = density_map.copy()
26
+ self.room_coordinates = room_coordinates
27
+ self.num_rooms = len(room_coordinates)
28
+
29
+ def create_room_masks(self) -> List[np.ndarray]:
30
+ """
31
+ Create binary masks for each room using their GT coordinates.
32
+
33
+ Returns:
34
+ List of binary masks, one for each room
35
+ """
36
+ h, w = self.original_density_map.shape
37
+ room_masks = []
38
+
39
+ for room_coords in self.room_coordinates:
40
+ mask = np.zeros((h, w), dtype=np.uint8)
41
+
42
+ if len(room_coords) >= 3: # Need at least 3 points for a polygon
43
+ # Convert coordinates to numpy array
44
+ coords = np.array(room_coords)
45
+ x_coords = coords[:, 0]
46
+ y_coords = coords[:, 1]
47
+
48
+ # Create polygon mask using skimage
49
+ rr, cc = polygon(y_coords, x_coords, shape=(h, w))
50
+ mask[rr, cc] = 1
51
+
52
+ room_masks.append(mask)
53
+
54
+ return room_masks
55
+
56
+ def drop_rooms_random(self, dropout_rate: float = 0.3, seed: Optional[int] = None) -> Tuple[np.ndarray, List[int]]:
57
+ """
58
+ Randomly drop rooms from the density map.
59
+
60
+ Args:
61
+ dropout_rate: Fraction of rooms to drop (0.0 to 1.0)
62
+ seed: Random seed for reproducibility
63
+
64
+ Returns:
65
+ Tuple of (modified_density_map, list_of_dropped_room_indices)
66
+ """
67
+ if seed is not None:
68
+ random.seed(seed)
69
+ np.random.seed(seed)
70
+
71
+ # Determine number of rooms to drop
72
+ num_to_drop = int(self.num_rooms * dropout_rate)
73
+
74
+ # Randomly select room indices to drop
75
+ room_indices = list(range(self.num_rooms))
76
+ dropped_indices = random.sample(room_indices, num_to_drop)
77
+
78
+ return self._apply_dropout(dropped_indices), dropped_indices
79
+
80
+ def drop_rooms_by_indices(self, room_indices: List[int]) -> np.ndarray:
81
+ """
82
+ Drop specific rooms by their indices.
83
+
84
+ Args:
85
+ room_indices: List of room indices to drop
86
+
87
+ Returns:
88
+ Modified density map with specified rooms removed
89
+ """
90
+ return self._apply_dropout(room_indices)
91
+
92
+ def drop_rooms_by_area(
93
+ self, min_area: Optional[int] = None, max_area: Optional[int] = None
94
+ ) -> Tuple[np.ndarray, List[int]]:
95
+ """
96
+ Drop rooms based on their area constraints.
97
+
98
+ Args:
99
+ min_area: Minimum area threshold (drop rooms smaller than this)
100
+ max_area: Maximum area threshold (drop rooms larger than this)
101
+
102
+ Returns:
103
+ Tuple of (modified_density_map, list_of_dropped_room_indices)
104
+ """
105
+ room_masks = self.create_room_masks()
106
+ dropped_indices = []
107
+
108
+ for i, mask in enumerate(room_masks):
109
+ area = np.sum(mask)
110
+
111
+ should_drop = False
112
+ if min_area is not None and area < min_area:
113
+ should_drop = True
114
+ if max_area is not None and area > max_area:
115
+ should_drop = True
116
+
117
+ if should_drop:
118
+ dropped_indices.append(i)
119
+
120
+ return self._apply_dropout(dropped_indices), dropped_indices
121
+
122
+ def _apply_dropout(self, room_indices_to_drop: List[int]) -> np.ndarray:
123
+ """
124
+ Apply dropout by removing specified rooms from the density map.
125
+
126
+ Args:
127
+ room_indices_to_drop: List of room indices to remove
128
+
129
+ Returns:
130
+ Modified density map with rooms removed
131
+ """
132
+ modified_map = self.original_density_map.copy()
133
+ room_masks = self.create_room_masks()
134
+
135
+ # Remove each specified room
136
+ for room_idx in room_indices_to_drop:
137
+ if 0 <= room_idx < len(room_masks):
138
+ mask = room_masks[room_idx]
139
+ # Set pixels in the room area to background (black/0)
140
+ modified_map[mask == 1] = 0
141
+
142
+ return modified_map
143
+
144
+ def visualize_dropout(
145
+ self, original_map: np.ndarray, modified_map: np.ndarray, dropped_indices: List[int]
146
+ ) -> np.ndarray:
147
+ """
148
+ Create a visualization showing the dropout effect.
149
+
150
+ Args:
151
+ original_map: Original density map
152
+ modified_map: Modified density map after dropout
153
+ dropped_indices: Indices of dropped rooms
154
+
155
+ Returns:
156
+ Visualization image with original and modified maps side by side
157
+ """
158
+ h, w = original_map.shape
159
+
160
+ # Create side-by-side comparison
161
+ vis = np.zeros((h, w * 2), dtype=np.uint8)
162
+ vis[:, :w] = original_map
163
+ vis[:, w:] = modified_map
164
+
165
+ # Highlight dropped rooms in red on the original map
166
+ if len(dropped_indices) > 0:
167
+ room_masks = self.create_room_masks()
168
+ vis_color = cv2.cvtColor(vis, cv2.COLOR_GRAY2BGR)
169
+
170
+ for idx in dropped_indices:
171
+ if 0 <= idx < len(room_masks):
172
+ mask = room_masks[idx]
173
+ # Highlight in red on the left (original) side
174
+ vis_color[mask == 1, 0] = 0 # Blue channel
175
+ vis_color[mask == 1, 1] = 0 # Green channel
176
+ vis_color[mask == 1, 2] = 255 # Red channel
177
+
178
+ return vis_color
179
+
180
+ return cv2.cvtColor(vis, cv2.COLOR_GRAY2BGR)
181
+
182
+
183
+ # Example usage and testing
184
+ def example_usage():
185
+ """
186
+ Example of how to use the RoomDropoutStrategy class.
187
+ """
188
+ # Create a sample density map (200x200 image)
189
+ density_map = np.zeros((200, 200), dtype=np.uint8)
190
+
191
+ # Create some sample room coordinates (rectangles and polygons)
192
+ room_coordinates = [
193
+ # Room 1: Rectangle
194
+ [(20, 20), (80, 20), (80, 60), (20, 60)],
195
+ # Room 2: Another rectangle
196
+ [(100, 30), (180, 30), (180, 80), (100, 80)],
197
+ # Room 3: L-shaped room
198
+ [(30, 100), (90, 100), (90, 130), (60, 130), (60, 160), (30, 160)],
199
+ # Room 4: Triangle
200
+ [(120, 120), (160, 120), (140, 160)],
201
+ # Room 5: Pentagon
202
+ [(50, 180), (70, 170), (90, 180), (80, 195), (40, 195)],
203
+ ]
204
+
205
+ # Fill the density map with white pixels for each room
206
+ for room_coords in room_coordinates:
207
+ coords = np.array(room_coords)
208
+ x_coords = coords[:, 0]
209
+ y_coords = coords[:, 1]
210
+
211
+ from skimage.draw import polygon
212
+
213
+ rr, cc = polygon(y_coords, x_coords, shape=density_map.shape)
214
+ density_map[rr, cc] = 255 # White pixels for rooms
215
+
216
+ # Initialize the dropout strategy
217
+ dropout_strategy = RoomDropoutStrategy(density_map, room_coordinates)
218
+
219
+ # Example 1: Random dropout
220
+ print("Example 1: Random dropout (30% of rooms)")
221
+ modified_map1, dropped_indices1 = dropout_strategy.drop_rooms_random(dropout_rate=0.3, seed=42)
222
+ print(f"Dropped rooms: {dropped_indices1}")
223
+
224
+ # Example 2: Drop specific rooms
225
+ print("\nExample 2: Drop specific rooms (indices 0 and 2)")
226
+ modified_map2 = dropout_strategy.drop_rooms_by_indices([0, 2])
227
+
228
+ # Example 3: Drop rooms by area
229
+ print("\nExample 3: Drop rooms with area > 3000 pixels")
230
+ modified_map3, dropped_indices3 = dropout_strategy.drop_rooms_by_area(max_area=3000)
231
+ print(f"Dropped rooms by area: {dropped_indices3}")
232
+
233
+ return density_map, modified_map1, modified_map2, modified_map3
234
+
235
+
236
+ if __name__ == "__main__":
237
+ example_usage()
datasets/transforms.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ from detectron2.data import transforms as T
4
+
5
+
6
+ class Resize(T.Augmentation):
7
+ """Resize image to a fixed target size"""
8
+
9
+ def __init__(self, shape, interp=Image.BICUBIC):
10
+ """
11
+ Args:
12
+ shape: (h, w) tuple or a int
13
+ interp: PIL interpolation method
14
+ """
15
+ if isinstance(shape, int):
16
+ shape = (shape, shape)
17
+ shape = tuple(shape)
18
+ self._init(locals())
19
+
20
+ def get_transform(self, image):
21
+ return T.ResizeTransform(image.shape[0], image.shape[1], self.shape[0], self.shape[1], self.interp)
22
+
23
+
24
+ # Custom transform that resizes and then pads to fixed size
25
+ class ResizeAndPad(T.Augmentation):
26
+ def __init__(self, target_size, pad_value=0, interp=Image.BICUBIC):
27
+ super().__init__()
28
+ self.target_size = target_size # (height, width)
29
+ self.interp = interp
30
+ self.pad_value = pad_value
31
+
32
+ def get_transform(self, img):
33
+ h, w = img.shape[:2]
34
+ scale = min(self.target_size[0] / h, self.target_size[1] / w)
35
+ new_h, new_w = int(h * scale), int(w * scale)
36
+
37
+ # First resize preserving aspect ratio
38
+ resize_t = T.ResizeTransform(h, w, new_h, new_w, self.interp)
39
+
40
+ # Then pad to target size
41
+ pad_h, pad_w = self.target_size[0] - new_h, self.target_size[1] - new_w
42
+ top = pad_h // 2
43
+ left = pad_w // 2
44
+ pad_t = T.PadTransform(left, top, pad_w - left, pad_h - top, new_h, new_w, pad_value=self.pad_value)
45
+
46
+ return T.TransformList([resize_t, pad_t])
detectron2/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ from .utils.env import setup_environment
4
+
5
+ setup_environment()
6
+
7
+
8
+ # This line will be programatically read/write by setup.py.
9
+ # Leave them at the bottom of this file and don't touch them.
10
+ __version__ = "0.6"
detectron2/checkpoint/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # File:
4
+
5
+
6
+ from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
7
+
8
+ from . import catalog as _UNUSED # register the handler
9
+ from .detection_checkpoint import DetectionCheckpointer
10
+
11
+ __all__ = ["Checkpointer", "PeriodicCheckpointer", "DetectionCheckpointer"]
detectron2/checkpoint/c2_model_loading.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import re
5
+ from typing import Dict, List
6
+
7
+ import torch
8
+ from tabulate import tabulate
9
+
10
+
11
+ def convert_basic_c2_names(original_keys):
12
+ """
13
+ Apply some basic name conversion to names in C2 weights.
14
+ It only deals with typical backbone models.
15
+
16
+ Args:
17
+ original_keys (list[str]):
18
+ Returns:
19
+ list[str]: The same number of strings matching those in original_keys.
20
+ """
21
+ layer_keys = copy.deepcopy(original_keys)
22
+ layer_keys = [
23
+ {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys
24
+ ] # some hard-coded mappings
25
+
26
+ layer_keys = [k.replace("_", ".") for k in layer_keys]
27
+ layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys]
28
+ layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys]
29
+ # Uniform both bn and gn names to "norm"
30
+ layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys]
31
+ layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys]
32
+ layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys]
33
+ layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys]
34
+ layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys]
35
+ layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys]
36
+ layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys]
37
+ layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys]
38
+ layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys]
39
+ layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys]
40
+
41
+ # stem
42
+ layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys]
43
+ # to avoid mis-matching with "conv1" in other components (e.g. detection head)
44
+ layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys]
45
+
46
+ # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5)
47
+ # layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys]
48
+ # layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys]
49
+ # layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys]
50
+ # layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys]
51
+
52
+ # blocks
53
+ layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys]
54
+ layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
55
+ layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
56
+ layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
57
+
58
+ # DensePose substitutions
59
+ layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys]
60
+ layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys]
61
+ layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys]
62
+ layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys]
63
+ layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys]
64
+ return layer_keys
65
+
66
+
67
+ def convert_c2_detectron_names(weights):
68
+ """
69
+ Map Caffe2 Detectron weight names to Detectron2 names.
70
+
71
+ Args:
72
+ weights (dict): name -> tensor
73
+
74
+ Returns:
75
+ dict: detectron2 names -> tensor
76
+ dict: detectron2 names -> C2 names
77
+ """
78
+ logger = logging.getLogger(__name__)
79
+ logger.info("Renaming Caffe2 weights ......")
80
+ original_keys = sorted(weights.keys())
81
+ layer_keys = copy.deepcopy(original_keys)
82
+
83
+ layer_keys = convert_basic_c2_names(layer_keys)
84
+
85
+ # --------------------------------------------------------------------------
86
+ # RPN hidden representation conv
87
+ # --------------------------------------------------------------------------
88
+ # FPN case
89
+ # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
90
+ # shared for all other levels, hence the appearance of "fpn2"
91
+ layer_keys = [k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys]
92
+ # Non-FPN case
93
+ layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
94
+
95
+ # --------------------------------------------------------------------------
96
+ # RPN box transformation conv
97
+ # --------------------------------------------------------------------------
98
+ # FPN case (see note above about "fpn2")
99
+ layer_keys = [k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys]
100
+ layer_keys = [
101
+ k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits") for k in layer_keys
102
+ ]
103
+ # Non-FPN case
104
+ layer_keys = [k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys]
105
+ layer_keys = [k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits") for k in layer_keys]
106
+
107
+ # --------------------------------------------------------------------------
108
+ # Fast R-CNN box head
109
+ # --------------------------------------------------------------------------
110
+ layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
111
+ layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
112
+ layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
113
+ layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
114
+ # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
115
+ layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
116
+
117
+ # --------------------------------------------------------------------------
118
+ # FPN lateral and output convolutions
119
+ # --------------------------------------------------------------------------
120
+ def fpn_map(name):
121
+ """
122
+ Look for keys with the following patterns:
123
+ 1) Starts with "fpn.inner."
124
+ Example: "fpn.inner.res2.2.sum.lateral.weight"
125
+ Meaning: These are lateral pathway convolutions
126
+ 2) Starts with "fpn.res"
127
+ Example: "fpn.res2.2.sum.weight"
128
+ Meaning: These are FPN output convolutions
129
+ """
130
+ splits = name.split(".")
131
+ norm = ".norm" if "norm" in splits else ""
132
+ if name.startswith("fpn.inner."):
133
+ # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
134
+ stage = int(splits[2][len("res") :])
135
+ return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
136
+ elif name.startswith("fpn.res"):
137
+ # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
138
+ stage = int(splits[1][len("res") :])
139
+ return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
140
+ return name
141
+
142
+ layer_keys = [fpn_map(k) for k in layer_keys]
143
+
144
+ # --------------------------------------------------------------------------
145
+ # Mask R-CNN mask head
146
+ # --------------------------------------------------------------------------
147
+ # roi_heads.StandardROIHeads case
148
+ layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
149
+ layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
150
+ layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
151
+ # roi_heads.Res5ROIHeads case
152
+ layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
153
+
154
+ # --------------------------------------------------------------------------
155
+ # Keypoint R-CNN head
156
+ # --------------------------------------------------------------------------
157
+ # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
158
+ layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
159
+ layer_keys = [k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys]
160
+ layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
161
+
162
+ # --------------------------------------------------------------------------
163
+ # Done with replacements
164
+ # --------------------------------------------------------------------------
165
+ assert len(set(layer_keys)) == len(layer_keys)
166
+ assert len(original_keys) == len(layer_keys)
167
+
168
+ new_weights = {}
169
+ new_keys_to_original_keys = {}
170
+ for orig, renamed in zip(original_keys, layer_keys):
171
+ new_keys_to_original_keys[renamed] = orig
172
+ if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
173
+ # remove the meaningless prediction weight for background class
174
+ new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
175
+ new_weights[renamed] = weights[orig][new_start_idx:]
176
+ logger.info(
177
+ "Remove prediction weight for background class in {}. The shape changes from "
178
+ "{} to {}.".format(renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape))
179
+ )
180
+ elif renamed.startswith("cls_score."):
181
+ # move weights of bg class from original index 0 to last index
182
+ logger.info(
183
+ "Move classification weights for background class in {} from index 0 to "
184
+ "index {}.".format(renamed, weights[orig].shape[0] - 1)
185
+ )
186
+ new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
187
+ else:
188
+ new_weights[renamed] = weights[orig]
189
+
190
+ return new_weights, new_keys_to_original_keys
191
+
192
+
193
+ # Note the current matching is not symmetric.
194
+ # it assumes model_state_dict will have longer names.
195
+ def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
196
+ """
197
+ Match names between the two state-dict, and returns a new chkpt_state_dict with names
198
+ converted to match model_state_dict with heuristics. The returned dict can be later
199
+ loaded with fvcore checkpointer.
200
+ If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
201
+ model and will be renamed at first.
202
+
203
+ Strategy: suppose that the models that we will create will have prefixes appended
204
+ to each of its keys, for example due to an extra level of nesting that the original
205
+ pre-trained weights from ImageNet won't contain. For example, model.state_dict()
206
+ might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
207
+ res2.conv1.weight. We thus want to match both parameters together.
208
+ For that, we look for each model weight, look among all loaded keys if there is one
209
+ that is a suffix of the current weight name, and use it if that's the case.
210
+ If multiple matches exist, take the one with longest size
211
+ of the corresponding name. For example, for the same model as before, the pretrained
212
+ weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
213
+ we want to match backbone[0].body.conv1.weight to conv1.weight, and
214
+ backbone[0].body.res2.conv1.weight to res2.conv1.weight.
215
+ """
216
+ model_keys = sorted(model_state_dict.keys())
217
+ if c2_conversion:
218
+ ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
219
+ # original_keys: the name in the original dict (before renaming)
220
+ else:
221
+ original_keys = {x: x for x in ckpt_state_dict.keys()}
222
+ ckpt_keys = sorted(ckpt_state_dict.keys())
223
+
224
+ def match(a, b):
225
+ # Matched ckpt_key should be a complete (starts with '.') suffix.
226
+ # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
227
+ # but matches whatever_conv1 or mesh_head.whatever_conv1.
228
+ return a == b or a.endswith("." + b)
229
+
230
+ # get a matrix of string matches, where each (i, j) entry correspond to the size of the
231
+ # ckpt_key string, if it matches
232
+ match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
233
+ match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
234
+ # use the matched one with longest size in case of multiple matches
235
+ max_match_size, idxs = match_matrix.max(1)
236
+ # remove indices that correspond to no-match
237
+ idxs[max_match_size == 0] = -1
238
+
239
+ logger = logging.getLogger(__name__)
240
+ # matched_pairs (matched checkpoint key --> matched model key)
241
+ matched_keys = {}
242
+ result_state_dict = {}
243
+ for idx_model, idx_ckpt in enumerate(idxs.tolist()):
244
+ if idx_ckpt == -1:
245
+ continue
246
+ key_model = model_keys[idx_model]
247
+ key_ckpt = ckpt_keys[idx_ckpt]
248
+ value_ckpt = ckpt_state_dict[key_ckpt]
249
+ shape_in_model = model_state_dict[key_model].shape
250
+
251
+ if shape_in_model != value_ckpt.shape:
252
+ logger.warning(
253
+ "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
254
+ key_ckpt, value_ckpt.shape, key_model, shape_in_model
255
+ )
256
+ )
257
+ logger.warning("{} will not be loaded. Please double check and see if this is desired.".format(key_ckpt))
258
+ continue
259
+
260
+ assert key_model not in result_state_dict
261
+ result_state_dict[key_model] = value_ckpt
262
+ if key_ckpt in matched_keys: # already added to matched_keys
263
+ logger.error(
264
+ "Ambiguity found for {} in checkpoint!"
265
+ "It matches at least two keys in the model ({} and {}).".format(
266
+ key_ckpt, key_model, matched_keys[key_ckpt]
267
+ )
268
+ )
269
+ raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
270
+
271
+ matched_keys[key_ckpt] = key_model
272
+
273
+ # logging:
274
+ matched_model_keys = sorted(matched_keys.values())
275
+ if len(matched_model_keys) == 0:
276
+ logger.warning("No weights in checkpoint matched with model.")
277
+ return ckpt_state_dict
278
+ common_prefix = _longest_common_prefix(matched_model_keys)
279
+ rev_matched_keys = {v: k for k, v in matched_keys.items()}
280
+ original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
281
+
282
+ model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
283
+ table = []
284
+ memo = set()
285
+ for key_model in matched_model_keys:
286
+ if key_model in memo:
287
+ continue
288
+ if key_model in model_key_groups:
289
+ group = model_key_groups[key_model]
290
+ memo |= set(group)
291
+ shapes = [tuple(model_state_dict[k].shape) for k in group]
292
+ table.append(
293
+ (
294
+ _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
295
+ _group_str([original_keys[k] for k in group]),
296
+ " ".join([str(x).replace(" ", "") for x in shapes]),
297
+ )
298
+ )
299
+ else:
300
+ key_checkpoint = original_keys[key_model]
301
+ shape = str(tuple(model_state_dict[key_model].shape))
302
+ table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
303
+ table_str = tabulate(table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"])
304
+ logger.info(
305
+ "Following weights matched with "
306
+ + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
307
+ + ":\n"
308
+ + table_str
309
+ )
310
+
311
+ unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())]
312
+ for k in unmatched_ckpt_keys:
313
+ result_state_dict[k] = ckpt_state_dict[k]
314
+ return result_state_dict
315
+
316
+
317
+ def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
318
+ """
319
+ Params in the same submodule are grouped together.
320
+
321
+ Args:
322
+ keys: names of all parameters
323
+ original_names: mapping from parameter name to their name in the checkpoint
324
+
325
+ Returns:
326
+ dict[name -> all other names in the same group]
327
+ """
328
+
329
+ def _submodule_name(key):
330
+ pos = key.rfind(".")
331
+ if pos < 0:
332
+ return None
333
+ prefix = key[: pos + 1]
334
+ return prefix
335
+
336
+ all_submodules = [_submodule_name(k) for k in keys]
337
+ all_submodules = [x for x in all_submodules if x]
338
+ all_submodules = sorted(all_submodules, key=len)
339
+
340
+ ret = {}
341
+ for prefix in all_submodules:
342
+ group = [k for k in keys if k.startswith(prefix)]
343
+ if len(group) <= 1:
344
+ continue
345
+ original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
346
+ if len(original_name_lcp) == 0:
347
+ # don't group weights if original names don't share prefix
348
+ continue
349
+
350
+ for k in group:
351
+ if k in ret:
352
+ continue
353
+ ret[k] = group
354
+ return ret
355
+
356
+
357
+ def _longest_common_prefix(names: List[str]) -> str:
358
+ """
359
+ ["abc.zfg", "abc.zef"] -> "abc."
360
+ """
361
+ names = [n.split(".") for n in names]
362
+ m1, m2 = min(names), max(names)
363
+ ret = [a for a, b in zip(m1, m2) if a == b]
364
+ ret = ".".join(ret) + "." if len(ret) else ""
365
+ return ret
366
+
367
+
368
+ def _longest_common_prefix_str(names: List[str]) -> str:
369
+ m1, m2 = min(names), max(names)
370
+ lcp = [a for a, b in zip(m1, m2) if a == b]
371
+ lcp = "".join(lcp)
372
+ return lcp
373
+
374
+
375
+ def _group_str(names: List[str]) -> str:
376
+ """
377
+ Turn "common1", "common2", "common3" into "common{1,2,3}"
378
+ """
379
+ lcp = _longest_common_prefix_str(names)
380
+ rest = [x[len(lcp) :] for x in names]
381
+ rest = "{" + ",".join(rest) + "}"
382
+ ret = lcp + rest
383
+
384
+ # add some simplification for BN specifically
385
+ ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
386
+ ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
387
+ return ret
detectron2/checkpoint/catalog.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+
4
+ from detectron2.utils.file_io import PathHandler, PathManager
5
+
6
+
7
+ class ModelCatalog(object):
8
+ """
9
+ Store mappings from names to third-party models.
10
+ """
11
+
12
+ S3_C2_DETECTRON_PREFIX = "https://dl.fbaipublicfiles.com/detectron"
13
+
14
+ # MSRA models have STRIDE_IN_1X1=True. False otherwise.
15
+ # NOTE: all BN models here have fused BN into an affine layer.
16
+ # As a result, you should only load them to a model with "FrozenBN".
17
+ # Loading them to a model with regular BN or SyncBN is wrong.
18
+ # Even when loaded to FrozenBN, it is still different from affine by an epsilon,
19
+ # which should be negligible for training.
20
+ # NOTE: all models here uses PIXEL_STD=[1,1,1]
21
+ # NOTE: Most of the BN models here are no longer used. We use the
22
+ # re-converted pre-trained models under detectron2 model zoo instead.
23
+ C2_IMAGENET_MODELS = {
24
+ "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
25
+ "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
26
+ "FAIR/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
27
+ "FAIR/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
28
+ "FAIR/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
29
+ "FAIR/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl",
30
+ "FAIR/X-152-32x8d-IN5k": "ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl",
31
+ }
32
+
33
+ C2_DETECTRON_PATH_FORMAT = "{prefix}/{url}/output/train/{dataset}/{type}/model_final.pkl" # noqa B950
34
+
35
+ C2_DATASET_COCO = "coco_2014_train%3Acoco_2014_valminusminival"
36
+ C2_DATASET_COCO_KEYPOINTS = "keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival"
37
+
38
+ # format: {model_name} -> part of the url
39
+ C2_DETECTRON_MODELS = {
40
+ "35857197/e2e_faster_rcnn_R-50-C4_1x": "35857197/12_2017_baselines/e2e_faster_rcnn_R-50-C4_1x.yaml.01_33_49.iAX0mXvW", # noqa B950
41
+ "35857345/e2e_faster_rcnn_R-50-FPN_1x": "35857345/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_1x.yaml.01_36_30.cUF7QR7I", # noqa B950
42
+ "35857890/e2e_faster_rcnn_R-101-FPN_1x": "35857890/12_2017_baselines/e2e_faster_rcnn_R-101-FPN_1x.yaml.01_38_50.sNxI7sX7", # noqa B950
43
+ "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "36761737/12_2017_baselines/e2e_faster_rcnn_X-101-32x8d-FPN_1x.yaml.06_31_39.5MIHi1fZ", # noqa B950
44
+ "35858791/e2e_mask_rcnn_R-50-C4_1x": "35858791/12_2017_baselines/e2e_mask_rcnn_R-50-C4_1x.yaml.01_45_57.ZgkA7hPB", # noqa B950
45
+ "35858933/e2e_mask_rcnn_R-50-FPN_1x": "35858933/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml.01_48_14.DzEQe4wC", # noqa B950
46
+ "35861795/e2e_mask_rcnn_R-101-FPN_1x": "35861795/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_1x.yaml.02_31_37.KqyEK4tT", # noqa B950
47
+ "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "36761843/12_2017_baselines/e2e_mask_rcnn_X-101-32x8d-FPN_1x.yaml.06_35_59.RZotkLKI", # noqa B950
48
+ "48616381/e2e_mask_rcnn_R-50-FPN_2x_gn": "GN/48616381/04_2018_gn_baselines/e2e_mask_rcnn_R-50-FPN_2x_gn_0416.13_23_38.bTlTI97Q", # noqa B950
49
+ "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "37697547/12_2017_baselines/e2e_keypoint_rcnn_R-50-FPN_1x.yaml.08_42_54.kdzV35ao", # noqa B950
50
+ "35998355/rpn_R-50-C4_1x": "35998355/12_2017_baselines/rpn_R-50-C4_1x.yaml.08_00_43.njH5oD9L", # noqa B950
51
+ "35998814/rpn_R-50-FPN_1x": "35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179", # noqa B950
52
+ "36225147/fast_R-50-FPN_1x": "36225147/12_2017_baselines/fast_rcnn_R-50-FPN_1x.yaml.08_39_09.L3obSdQ2", # noqa B950
53
+ }
54
+
55
+ @staticmethod
56
+ def get(name):
57
+ if name.startswith("Caffe2Detectron/COCO"):
58
+ return ModelCatalog._get_c2_detectron_baseline(name)
59
+ if name.startswith("ImageNetPretrained/"):
60
+ return ModelCatalog._get_c2_imagenet_pretrained(name)
61
+ raise RuntimeError("model not present in the catalog: {}".format(name))
62
+
63
+ @staticmethod
64
+ def _get_c2_imagenet_pretrained(name):
65
+ prefix = ModelCatalog.S3_C2_DETECTRON_PREFIX
66
+ name = name[len("ImageNetPretrained/") :]
67
+ name = ModelCatalog.C2_IMAGENET_MODELS[name]
68
+ url = "/".join([prefix, name])
69
+ return url
70
+
71
+ @staticmethod
72
+ def _get_c2_detectron_baseline(name):
73
+ name = name[len("Caffe2Detectron/COCO/") :]
74
+ url = ModelCatalog.C2_DETECTRON_MODELS[name]
75
+ if "keypoint_rcnn" in name:
76
+ dataset = ModelCatalog.C2_DATASET_COCO_KEYPOINTS
77
+ else:
78
+ dataset = ModelCatalog.C2_DATASET_COCO
79
+
80
+ if "35998355/rpn_R-50-C4_1x" in name:
81
+ # this one model is somehow different from others ..
82
+ type = "rpn"
83
+ else:
84
+ type = "generalized_rcnn"
85
+
86
+ # Detectron C2 models are stored in the structure defined in `C2_DETECTRON_PATH_FORMAT`.
87
+ url = ModelCatalog.C2_DETECTRON_PATH_FORMAT.format(
88
+ prefix=ModelCatalog.S3_C2_DETECTRON_PREFIX, url=url, type=type, dataset=dataset
89
+ )
90
+ return url
91
+
92
+
93
+ class ModelCatalogHandler(PathHandler):
94
+ """
95
+ Resolve URL like catalog://.
96
+ """
97
+
98
+ PREFIX = "catalog://"
99
+
100
+ def _get_supported_prefixes(self):
101
+ return [self.PREFIX]
102
+
103
+ def _get_local_path(self, path, **kwargs):
104
+ logger = logging.getLogger(__name__)
105
+ catalog_path = ModelCatalog.get(path[len(self.PREFIX) :])
106
+ logger.info("Catalog entry {} points to {}".format(path, catalog_path))
107
+ return PathManager.get_local_path(catalog_path, **kwargs)
108
+
109
+ def _open(self, path, mode="r", **kwargs):
110
+ return PathManager.open(self._get_local_path(path), mode, **kwargs)
111
+
112
+
113
+ PathManager.register_handler(ModelCatalogHandler())