diff --git a/LEGAL.md b/LEGAL.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..6fdb1dfc367716a58d514323d9ba5391fefa50d0
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Nan Xue
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index bd6988349f832ccae381a2419b7eedbb24b169a8..c03789130684cad6c265d4be96863e661d1aed7c 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,117 @@
----
-title: ScaleLSD
-emoji: π
-colorFrom: indigo
-colorTo: indigo
-sdk: gradio
-sdk_version: 5.33.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+# ScaleLSD: Scalable Deep Line Segment Detection Streamlined
+
+
+
+

+
+
+
+[Zeran Ke](https://calmke.github.io/)<sup>1,2</sup>, [Bin Tan](https://icetttb.github.io/)<sup>2</sup>, [Xianwei Zheng](https://jszy.whu.edu.cn/zhengxianwei/zh_CN/index.htm)<sup>1</sup>, [Yujun Shen](https://shenyujun.github.io/)<sup>2</sup>, [Tianfu Wu](https://research.ece.ncsu.edu/ivmcl/)<sup>3</sup>, [Nan Xue](https://xuenan.net/)<sup>2,✉</sup>
+
+
+<sup>1</sup>Wuhan University, <sup>2</sup>Ant Group, <sup>3</sup>NC State University
+
+
+
+
+
+
+
+
+## βοΈ Installation
+
+All code has been successfully tested on:
+
+- Ubuntu 22.04.5 LTS
+- CUDA 12.1
+- Python 3.10
+- Pytorch 2.5.1
+
+First clone this repo:
+
+```bash
+git clone https://github.com/ant-research/scalelsd.git
+```
+
+Then create the conda environment and install the dependencies:
+```bash
+conda create -n scalelsd python=3.10
+conda activate scalelsd
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
+pip install -r requirements.txt
+pip install -e . # Install scalelsd locally
+```
+
+## π₯π Gradio Demo
+
+### Line Segment Detection
+Before you start, please download our pre-trained [models](https://huggingface.co/cherubicxn/scalelsd) and place them into the `models` folder. Then run the Gradio demo:
+```bash
+python -m gradio_demo.inference
+```
+
+### Line Matching
+Because our line matching app is built on GlueStick with our ScaleLSD, you need to install [GlueStick](https://github.com/cvg/GlueStick) and download the weights of the GlueStick model. Then run the Gradio demo:
+```bash
+python -m gradio_demo.line_mat_gluestick
+```
+
+## π Inference
+
+To quickly start using our models for line segment detection, run the following command:
+```bash
+python -m predictor.predict --img $[IMAGE_PATH_OR_FOLDER]
+```
+
+You can also specify more params by:
+
+```bash
+python -m predictor.predict \
+ --ckpt $[MODEL_PATH] \
+ --img $[IMAGE_PATH_OR_FOLDER] \
+ --ext $[png/pdf/json] \
+ --threshold 10 \
+ --junction-hm 0.1 \
+ --disable-show
+```
+
+```bash
+OPTIONS:
+ --ckpt CKPT, -c CKPT
+ Path to the checkpoint file.
+ --img IMG, -i IMG Path to the image or folder containing images.
+ --ext EXT, -e EXT Output file extension (png/pdf/json).
+ --threshold THRESHOLD, -t THRESHOLD
+ Threshold for line segment detection.
+ --junction-hm JUNCTION_HM, -jh JUNCTION_HM
+ Junction heatmap threshold.
+ --num-junctions NUM_JUNCTIONS, -nj NUM_JUNCTIONS
+ Max number of junctions to detect.
+ --disable-show Disable showing the results.
+ --use_lsd Use LSD-Rectifier for line segment detection.
+ --use_nms Use Non-Maximum Suppression (NMS) for junction detection.
+```
+
+
+## π Related Third-party Projects
+
+- [HAWPv3](https://github.com/cherubicXN/hawp/tree/main)
+- [DeepLSD](https://github.com/cvg/DeepLSD)
+- [Progressive-x](https://github.com/danini/progressive-x/tree/vanishing-points)
+- [GlueStick](https://github.com/cvg/GlueStick)
+- [GlueFactory](https://github.com/cvg/glue-factory)
+- [LiMAP](https://github.com/cvg/limap)
+
+
+## π Citation
+
+If you find our work useful in your research, please consider citing:
+
+```bibtex
+@inproceedings{ScaleLSD,
+ title = {ScaleLSD: Scalable Deep Line Segment Detection Streamlined},
+ author = {Zeran Ke and Bin Tan and Xianwei Zheng and Yujun Shen and Tianfu Wu and Nan Xue},
+ booktitle = "IEEE Conference on Computer Vision and Pattern Recognition (CVPR)",
+ year = {2025},
+}
+```
diff --git a/gradio_demo/inference.py b/gradio_demo/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a38ec1fdee2f7fc610bea60be5d186947269545
--- /dev/null
+++ b/gradio_demo/inference.py
@@ -0,0 +1,252 @@
+import torch
+import cv2
+import os
+import gradio as gr
+import numpy as np
+import random
+from pathlib import Path
+import json
+
+from scalelsd.ssl.models.detector import ScaleLSD
+from scalelsd.base import show, WireframeGraph
+from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model
+
+# Title for the Gradio interface
+_TITLE = 'Gradio Demo of ScaleLSD for Structured Representation of Images'
+MAX_SEED = 1000
+
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    """Return a fresh seed from ``[0, MAX_SEED]`` when requested, else ``seed`` unchanged."""
+    return random.randint(0, MAX_SEED) if randomize_seed else seed
+
+def stop_run():
+    """Build the component updates that are applied when a run is stopped.
+
+    Returns:
+        A pair of ``gr.update`` objects: the first shows a primary button
+        labelled "Run", the second hides its target component.
+    """
+    show_run_button = gr.update(value="Run", variant="primary", visible=True)
+    hide_stop_button = gr.update(visible=False)
+    return show_run_button, hide_stop_button
+
+def process_image(
+    input_image,
+    model_name='scalelsd-vitbase-v2-train-sa1b.pt',
+    save_name='temp_output',
+    threshold=10,
+    junction_threshold_hm=0.008,
+    num_junctions_inference=512,
+    width=512,
+    height=512,
+    line_width=2,
+    juncs_size=4,
+    whitebg=0.0,
+    draw_junctions_only=False,
+    use_lsd=False,
+    use_nms=False,
+    edge_color='orange',
+    vertex_color='Cyan',
+    output_format='png',
+    seed=0,
+    randomize_seed=False
+):
+    """Run ScaleLSD on one image and export the rendered result.
+
+    Args:
+        input_image: RGB image as a NumPy array, or a path readable by
+            ``cv2.imread``.
+        model_name: Checkpoint file name, looked up inside ``models/``.
+        save_name: Base name (no extension) of the files written to
+            ``temp_output/``.
+        threshold: Stored as the painter's ``confidence_threshold``.
+        junction_threshold_hm: Stored as ``model.junction_threshold_hm``.
+        num_junctions_inference: Stored as ``model.num_junctions_inference``
+            (converted to ``int``).
+        width: Width the grayscale image is resized to before inference.
+        height: Height the grayscale image is resized to before inference.
+        line_width: Stored as the painter's ``line_width``.
+        juncs_size: Stored as the painter's ``marker_size``.
+        whitebg: When greater than 0, stored in ``show.Canvas.white_overlay``.
+        draw_junctions_only: Draw junctions instead of the full wireframe.
+        use_lsd: Forwarded to the model through ``meta['use_lsd']``.
+        use_nms: Forwarded to the model through ``meta['use_nms']``.
+        edge_color: Edge color passed to ``draw_wireframe``.
+        vertex_color: Vertex color passed to ``draw_wireframe``.
+        output_format: Extension of the extra figure that is written when
+            the value is not ``'png'``.
+        seed: Seed passed to ``fix_seeds`` unless ``randomize_seed`` is set.
+        randomize_seed: Replace ``seed`` with a random one.
+
+    Returns:
+        ``(image, json_file, fig_file)``: the rendered PNG read back with
+        its channel axis reversed, the path of the JSON wireframe dump, and
+        the path of the figure in the requested format.
+
+    Raises:
+        FileNotFoundError: If ``input_image`` is a path that cannot be read.
+    """
+    # set random seed
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    fix_seeds(seed)
+
+    # initialize model
+    # NOTE(review): the checkpoint is loaded again on every call.
+    ckpt = "models/" + model_name
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = load_scalelsd_model(ckpt, device)
+
+    # set model parameters
+    model.junction_threshold_hm = junction_threshold_hm
+    # UI number widgets may deliver floats; a junction count must be integral.
+    model.num_junctions_inference = int(num_junctions_inference)
+
+    # transform input image
+    if isinstance(input_image, np.ndarray):
+        image = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
+    else:
+        image = cv2.imread(input_image, 0)
+        if image is None:
+            # cv2.imread signals failure by returning None instead of raising.
+            raise FileNotFoundError(f"Cannot read image: {input_image}")
+
+    # resize (cv2.resize requires an integer target size)
+    ori_shape = image.shape[:2]
+    image_resized = cv2.resize(image, (int(width), int(height)))
+    image_tensor = torch.from_numpy(image_resized).float() / 255.0
+    # Use the same device as the model; a hard-coded 'cuda' breaks the CPU fallback.
+    image_tensor = image_tensor[None, None].to(device)
+
+    # meta data
+    meta = {
+        'width': ori_shape[1],
+        'height': ori_shape[0],
+        'filename': '',
+        'use_lsd': use_lsd,
+        'use_nms': use_nms,
+    }
+
+    # inference
+    with torch.no_grad():
+        outputs, _ = model(image_tensor, meta)
+    outputs = outputs[0]
+
+    # visual results
+    painter = show.painters.HAWPainter()
+    painter.confidence_threshold = threshold
+    painter.line_width = line_width
+    painter.marker_size = juncs_size
+    # NOTE(review): this is a class attribute that is only written for
+    # whitebg > 0, so an overlay set by an earlier call stays in effect.
+    if whitebg > 0.0:
+        show.Canvas.white_overlay = whitebg
+
+    def render(target_file):
+        """Draw the detection on top of the input image and save it to ``target_file``."""
+        with show.image_canvas(input_image, fig_file=target_file) as ax:
+            if draw_junctions_only:
+                painter.draw_junctions(ax, outputs)
+            else:
+                painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
+
+    temp_folder = "temp_output"
+    os.makedirs(temp_folder, exist_ok=True)
+
+    # A PNG is always rendered because it is read back for the preview.
+    fig_file = f"{temp_folder}/{save_name}.png"
+    render(fig_file)
+    result_image = cv2.imread(fig_file)
+
+    # Render a second time when another download format was requested.
+    if output_format != 'png':
+        fig_file = f"{temp_folder}/{save_name}.{output_format}"
+        render(fig_file)
+
+    json_file = f"{temp_folder}/{save_name}.json"
+    indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
+    wireframe = WireframeGraph(
+        outputs['juncs_pred'],
+        outputs['juncs_score'],
+        indices,
+        outputs['lines_score'],
+        outputs['width'],
+        outputs['height'],
+    )
+    with open(json_file, 'w') as f:
+        json.dump(wireframe.jsonize(), f)
+
+    # cv2.imread yields BGR channel order; reverse it for display.
+    return result_image[:, :, ::-1], json_file, fig_file
+
+
+def run_demo():
+ """Build the Gradio Blocks interface for single-image detection.
+
+ The layout holds an input/output image pair, Run/Stop buttons, two
+ download slots and an "Advanced Settings" accordion whose widgets are
+ passed, in order, to ``process_image``. The assembled ``gr.Blocks``
+ object is returned without being launched.
+ """
+ css = """
+ #col-container {
+ margin: 0 auto;
+ max-width: 800px;
+ }
+ """
+
+ with gr.Blocks(css=css, title=_TITLE) as demo:
+ with gr.Column(elem_id="col-container"):
+ gr.Markdown(f'# {_TITLE}')
+ gr.Markdown("Detect wireframe structures in images using ScaleLSD model")
+
+ # NOTE(review): `pid` is not referenced again in this function.
+ pid = gr.State()
+ figs_root = "assets/figs"
+ # os.listdir returns entries in arbitrary order, so which example is
+ # preloaded below (index 0) is not fixed across machines.
+ example_images = [os.path.join(figs_root, iname) for iname in os.listdir(figs_root)]
+
+ with gr.Row():
+ input_image = gr.Image(example_images[0], label="Input Image", type="numpy")
+ output_image = gr.Image(label="Detection Result")
+
+ with gr.Row():
+ run_btn = gr.Button(value="Run", variant="primary")
+ # NOTE(review): created hidden, and no handler in this function
+ # ever makes it visible.
+ stop_btn = gr.Button(value="Stop", variant="stop", visible=False)
+
+ with gr.Row():
+ json_file = gr.File(label="Download JSON Output", type="filepath")
+ image_file = gr.File(label="Download Image Output", type="filepath")
+
+ with gr.Accordion("Advanced Settings", open=True):
+ with gr.Row():
+ # Choices are the *.pt files found in ./models; the preselected
+ # value presumably has to be one of them - TODO confirm.
+ model_name = gr.Dropdown(
+ [ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt')],
+ value='scalelsd-vitbase-v1-train-sa1b.pt',
+ label="Model Selection"
+ )
+
+ with gr.Row():
+ save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files")
+
+ # Every widget below is passed explicitly to process_image, so the
+ # initial values here take the place of its keyword defaults.
+ with gr.Row():
+ with gr.Column():
+ threshold = gr.Number(10, label="Line Threshold")
+ junction_threshold_hm = gr.Number(0.008, label="Junction Threshold")
+ num_junctions_inference = gr.Number(1024, label="Max Number of Junctions")
+ width = gr.Number(512, label="Input Width")
+ height = gr.Number(512, label="Input Height")
+
+ with gr.Column():
+ draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only")
+ use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier")
+ use_nms = gr.Checkbox(True, label="Use NMS")
+ output_format = gr.Dropdown(
+ ['png', 'jpg', 'pdf'],
+ value='png',
+ label="Output Format"
+ )
+ whitebg = gr.Slider(0.0, 1.0, value=0.7, label="White Background Opacity")
+ line_width = gr.Number(2, label="Line Width")
+ juncs_size = gr.Number(8, label="Junctions Size")
+
+ with gr.Row():
+ edge_color = gr.Dropdown(
+ ['orange', 'midnightblue', 'red', 'green'],
+ value='orange',
+ label="Edge Color"
+ )
+ vertex_color = gr.Dropdown(
+ ['Cyan', 'deeppink', 'yellow', 'purple'],
+ value='Cyan',
+ label="Vertex Color"
+ )
+
+ with gr.Row():
+ randomize_seed = gr.Checkbox(False, label="Randomize Seed")
+ seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
+
+ # Clicking an example loads it into the input image component.
+ gr.Examples(
+ examples=example_images,
+ inputs=input_image,
+ )
+
+ # run event handler: the order of `inputs` must match the positional
+ # parameters of process_image
+ run_event = run_btn.click(
+ fn=process_image,
+ inputs=[
+ input_image,
+ model_name,
+ save_name,
+ threshold,
+ junction_threshold_hm,
+ num_junctions_inference,
+ width,
+ height,
+ line_width,
+ juncs_size,
+ whitebg,
+ draw_junctions_only,
+ use_lsd,
+ use_nms,
+ edge_color,
+ vertex_color,
+ output_format,
+ seed,
+ randomize_seed
+ ],
+ outputs=[output_image, json_file, image_file],
+ )
+
+ # stop event handler: cancels the pending run event and resets the buttons
+ stop_btn.click(
+ fn=stop_run,
+ outputs=[run_btn, stop_btn],
+ cancels=[run_event],
+ queue=False,
+ )
+
+
+ return demo
+
+if __name__ == "__main__":
+ run_demo().launch()
diff --git a/gradio_demo/line_mat_gluestick.py b/gradio_demo/line_mat_gluestick.py
new file mode 100644
index 0000000000000000000000000000000000000000..04fbb099df7645378907a5c000478183b2e59043
--- /dev/null
+++ b/gradio_demo/line_mat_gluestick.py
@@ -0,0 +1,386 @@
+import argparse
+import os
+from os.path import join
+import sys
+import numpy as np
+import cv2
+import torch
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+import gradio as gr
+import random
+
+from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
+from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
+
+from scalelsd.ssl.models.detector import ScaleLSD
+from scalelsd.base import show, WireframeGraph
+from scalelsd.ssl.datasets.transforms.homographic_transforms import sample_homography
+from scalelsd.ssl.misc.train_utils import fix_seeds
+from line_matching.two_view_pipeline import TwoViewPipeline
+
+from kornia.geometry import warp_perspective,transform_points
+
+class HADConfig:
+    """Class-level settings for random homography sampling in this module."""
+
+    # General settings
+    num_iter = 1
+    valid_border_margin = 3
+    allow_artifacts = False
+    patch_ratio = 0.85
+
+    # Which transformation components are enabled
+    translation = True
+    rotation = True
+    scale = True
+    perspective = True
+
+    # Amplitudes of the enabled components
+    scaling_amplitude = 0.2
+    perspective_amplitude_x = 0.2
+    perspective_amplitude_y = 0.2
+had_cfg = HADConfig()
+
+# Evaluation config handed to TwoViewPipeline in process_image.
+default_conf = {
+ 'name': 'two_view_pipeline',
+ 'use_lines': True,
+ # Feature extractor settings (points and line segments)
+ 'extractor': {
+ 'name': 'wireframe',
+ 'sp_params': {
+ 'force_num_keypoints': False,
+ 'max_num_keypoints': 2048,
+ },
+ 'wireframe_params': {
+ 'merge_points': True,
+ 'merge_line_endpoints': True,
+ # 'merge_line_endpoints': False,
+ },
+ 'max_n_lines': 512,
+ },
+ # Matcher settings; the checkpoint path is resolved relative to the
+ # installed GlueStick package root.
+ 'matcher': {
+ 'name': 'gluestick',
+ 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
+ 'trainable': False,
+ },
+ 'ground_truth': {
+ 'from_pose_depth': False,
+ }
+}
+
+# Title for the Gradio interface
+_TITLE = 'ScaleLSD-GlueStick Line Matching'
+MAX_SEED = 1000
+
+def sample_homographics(height, width):
+    """Sample one random homography for an image of the given size.
+
+    The sampling options come from the module-level ``had_cfg``. The
+    result is placed on the GPU when CUDA is available and on the CPU
+    otherwise, instead of requiring CUDA unconditionally.
+
+    Args:
+        height: Image height passed to ``sample_homography``.
+        width: Image width passed to ``sample_homography``.
+
+    Returns:
+        A dict with the sampled homography under ``'h.1'`` and its inverse
+        under ``'ih.1'``, both float tensors with a leading batch axis of
+        size 1.
+    """
+    homographic = sample_homography(
+        shape=(height, width),
+        perspective=had_cfg.perspective,
+        scaling=had_cfg.scale,
+        rotation=had_cfg.rotation,
+        translation=had_cfg.translation,
+        scaling_amplitude=had_cfg.scaling_amplitude,
+        perspective_amplitude_x=had_cfg.perspective_amplitude_x,
+        perspective_amplitude_y=had_cfg.perspective_amplitude_y,
+        patch_ratio=had_cfg.patch_ratio,
+        # Read the flag from the config (currently False) rather than
+        # repeating the literal here.
+        allow_artifacts=had_cfg.allow_artifacts,
+    )[0]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    homographic = torch.from_numpy(homographic[None]).float().to(device)
+    homographic_inv = torch.inverse(homographic)
+
+    return {
+        'h.1': homographic,
+        'ih.1': homographic_inv,
+    }
+
+def trans_image_with_homograpy(image):
+    """Warp ``image`` with a randomly sampled homography.
+
+    The interactive ``plt.imshow``/``plt.show`` preview was removed: this
+    function runs inside a Gradio callback, where a GUI window can block
+    the request, and the figure it created was never closed.
+
+    Args:
+        image: NumPy image indexed as (row, column, channel).
+
+    Returns:
+        The warped image as a ``uint8`` NumPy array with the same height
+        and width as the input.
+    """
+    h, w = image.shape[:2]
+    homography = sample_homographics(height=h, width=w)['h.1']
+
+    # Follow the homography's device instead of assuming CUDA.
+    image_tensor = torch.Tensor(image).permute(2, 0, 1)[None].to(homography.device)
+    image_warped = warp_perspective(image_tensor, homography, (h, w))
+    return image_warped[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    """Return a fresh seed from ``[0, MAX_SEED]`` when requested, else ``seed`` unchanged."""
+    return random.randint(0, MAX_SEED) if randomize_seed else seed
+
+def stop_run():
+    """Build the component updates that are applied when a run is stopped.
+
+    Returns:
+        A pair of ``gr.update`` objects: the first shows a primary button
+        labelled "Run", the second hides its target component.
+    """
+    show_run_button = gr.update(value="Run", variant="primary", visible=True)
+    hide_stop_button = gr.update(visible=False)
+    return show_run_button, hide_stop_button
+
+def clear_image2():
+    """Return ``None``, which clears the image component bound as output."""
+    return None
+
+def process_image(
+    input_image1='assets/figs/sa_1119229.jpg',
+    input_image2=None,
+    model_name='scalelsd-vitbase-v1-train-sa1b.pt',
+    save_name='temp',
+    threshold=5,
+    junction_threshold_hm=0.008,
+    num_junctions_inference=4096,
+    width=512,
+    height=512,
+    line_width=2,
+    juncs_size=4,
+    whitebg=1.0,
+    draw_junctions_only=False,
+    use_lsd=False,
+    use_nms=False,
+    edge_color='midnightblue',
+    vertex_color='deeppink',
+    output_format='png',
+    seed=0,
+    randomize_seed=False
+):
+    """Detect and match line segments between two views.
+
+    When ``input_image2`` is ``None`` the second view is produced by
+    warping the first one with a random homography.
+
+    Args:
+        input_image1: First view, as an RGB NumPy array or an image path.
+        input_image2: Second view in the same forms, or ``None``.
+        model_name: Forwarded to the extractor via ``update_conf``.
+        save_name: Prefix of the result files written to
+            ``temp_output/matching_results``.
+        threshold: Forwarded to the extractor via ``update_conf``.
+        junction_threshold_hm: Forwarded to the extractor via ``update_conf``.
+        num_junctions_inference: Forwarded to the extractor via ``update_conf``.
+        width: Forwarded to the extractor via ``update_conf``.
+        height: Forwarded to the extractor via ``update_conf``.
+        line_width: Accepted for interface parity; not used here.
+        juncs_size: Accepted for interface parity; not used here.
+        whitebg: Stored in ``show.Canvas.white_overlay``.
+        draw_junctions_only: Accepted for interface parity; not used here.
+        use_lsd: Forwarded to the extractor via ``update_conf``.
+        use_nms: Forwarded to the extractor via ``update_conf``.
+        edge_color: Edge color for the per-view wireframe renderings.
+        vertex_color: Vertex color for the per-view wireframe renderings.
+        output_format: Accepted for interface parity; not used here.
+        seed: Seed passed to ``fix_seeds`` unless ``randomize_seed`` is set.
+        randomize_seed: Replace ``seed`` with a random one.
+
+    Returns:
+        ``(image2, mat_image, det_image, det1_image, det2_image, mat_file,
+        det_file)``: the second view, the match plot, the side-by-side
+        detection plot and both single-view renderings as images with their
+        channel axis reversed, followed by the paths of the match plot and
+        of the detection plot.
+
+    Raises:
+        FileNotFoundError: If an input given as a path cannot be read.
+    """
+    # set random seed
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    fix_seeds(seed)
+
+    conf = {
+        'model_name': model_name,
+        'threshold': threshold,
+        'junction_threshold_hm': junction_threshold_hm,
+        'num_junctions_inference': num_junctions_inference,
+        'use_lsd': use_lsd,
+        'use_nms': use_nms,
+        'width': width,
+        'height': height,
+    }
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    pipeline_model = TwoViewPipeline(default_conf).to(device).eval()
+    pipeline_model.extractor.update_conf(conf)
+
+    saveto = 'temp_output/matching_results'
+    # cv2.imwrite fails silently when the target directory is missing.
+    os.makedirs(saveto, exist_ok=True)
+
+    def as_bgr(image):
+        """Return ``image`` (RGB array or file path) in OpenCV's BGR order."""
+        if isinstance(image, np.ndarray):
+            return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        loaded = cv2.imread(image)
+        if loaded is None:
+            raise FileNotFoundError(f'Cannot read image: {image}')
+        return loaded
+
+    # Both views are written to disk because the grayscale reads and the
+    # canvases below take file paths.
+    image1 = as_bgr(input_image1)
+    input_image1 = f'{saveto}/image.png'
+    cv2.imwrite(input_image1, image1)
+
+    if input_image2 is None:
+        image2 = trans_image_with_homograpy(image1)
+    else:
+        image2 = as_bgr(input_image2)
+    input_image2 = f'{saveto}/image2.png'
+    cv2.imwrite(input_image2, image2)
+
+    gray0 = cv2.imread(input_image1, 0)
+    gray1 = cv2.imread(input_image2, 0)
+
+    torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1)
+    torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None]
+
+    x = {'image0': torch_gray0, 'image1': torch_gray1}
+    pred = pipeline_model(x)
+    pred = batch_to_np(pred)
+
+    line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
+    line_matches = pred["line_matches0"]
+
+    # -1 marks a line in view 0 without a partner in view 1.
+    valid_matches = line_matches != -1
+    match_indices = line_matches[valid_matches]
+    matched_lines0 = line_seg0[valid_matches]
+    matched_lines1 = line_seg1[match_indices]
+
+    img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)
+
+    # Detection plot. It gets its own file: previously both plots were
+    # saved to the same "*_mat.png" path, so both download links pointed
+    # at the match plot.
+    det_file = f'{saveto}/{save_name}_det.png'
+    plot_images([img0, img1], dpi=200, pad=2.0)
+    plot_lines([line_seg0, line_seg1], ps=4, lw=2)
+    plt.gcf().canvas.manager.set_window_title('Detected Lines')
+    # plt.tight_layout()
+    plt.savefig(det_file)
+    plt.close()  # release the figure; repeated requests would otherwise pile them up
+    det_image = cv2.imread(det_file)[:, :, ::-1]
+
+    # Match plot.
+    mat_file = f'{saveto}/{save_name}_mat.png'
+    plot_images([img0, img1], dpi=200, pad=2.0)
+    plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
+    plt.gcf().canvas.manager.set_window_title('Line Matches')
+    # plt.tight_layout()
+    plt.savefig(mat_file)
+    plt.close()
+    mat_image = cv2.imread(mat_file)[:, :, ::-1]
+
+    show.Canvas.white_overlay = whitebg
+    painter = show.painters.HAWPainter()
+
+    def render_view(image_path, lines, fig_file):
+        """Draw ``lines`` over the image at ``image_path`` and read the figure back."""
+        outputs = {'lines_pred': lines.reshape(-1, 4)}
+        with show.image_canvas(image_path, fig_file=fig_file) as ax:
+            painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
+        return cv2.imread(fig_file)[:, :, ::-1]
+
+    det1_image = render_view(input_image1, line_seg0, f'{saveto}/{save_name}_det1.png')
+    det2_image = render_view(input_image2, line_seg1, f'{saveto}/{save_name}_det2.png')
+
+    return image2[:, :, ::-1], mat_image, det_image, det1_image, det2_image, mat_file, det_file
+
+
+def demo():
+ """Build the Gradio Blocks interface for two-view line matching.
+
+ The layout holds two input images, the matching and detection result
+ images, Run/Stop buttons, two download slots and an "Advanced Settings"
+ accordion whose widgets are passed, in order, to ``process_image``.
+ The assembled ``gr.Blocks`` object is returned without being launched.
+ """
+ css = """
+ #col-container {
+ margin: 0 auto;
+ max-width: 800px;
+ }
+ """
+
+ # NOTE(review): inside this function the name `demo` is rebound to the
+ # Blocks object, shadowing the function itself.
+ with gr.Blocks(css=css, title=_TITLE) as demo:
+ with gr.Column(elem_id="col-container"):
+ gr.Markdown(f'# {_TITLE}')
+ gr.Markdown("Detect wireframe structures in images using ScaleLSD model")
+
+ # NOTE(review): `pid` is not referenced again in this function.
+ pid = gr.State()
+ figs_root = "assets/mat_figs"
+ # os.listdir returns entries in arbitrary order, so which example is
+ # preloaded below (index 0) is not fixed across machines.
+ example_single = [os.path.join(figs_root, 'single', iname) for iname in os.listdir(figs_root+'/single')]
+ # Single-image examples leave the second slot empty.
+ example_pairs = [[img, None] for img in example_single]
+ # Hand-picked reference/target pairs from assets/mat_figs/pairs.
+ example_pairs += [
+ [os.path.join(figs_root, 'pairs', f'ref_{i}.png'),
+ os.path.join(figs_root, 'pairs', f'tgt_{i}.png')]
+ for i in [10, 72, 76, 95, 149, 151]
+ ]
+
+ with gr.Row():
+ input_image1 = gr.Image(example_pairs[0][0], label="Input Image1", type="numpy")
+ input_image2 = gr.Image(label="Input Image2", type="numpy")
+
+ with gr.Row():
+ mat_images = gr.Image(label="Matching Results")
+ with gr.Row():
+ det_images = gr.Image(label="Detection Results")
+ with gr.Row():
+ det_image1 = gr.Image(label="Detection1")
+ det_image2 = gr.Image(label="Detection2")
+
+ with gr.Row():
+ run_btn = gr.Button(value="Run", variant="primary")
+ # NOTE(review): created hidden, and no handler in this function
+ # ever makes it visible.
+ stop_btn = gr.Button(value="Stop", variant="stop", visible=False)
+
+ with gr.Row():
+ mat_file = gr.File(label="Download Matching Result", type="filepath")
+ det_file = gr.File(label="Download Detection Result", type="filepath")
+
+ with gr.Accordion("Advanced Settings", open=True):
+ with gr.Row():
+ # Choices are the *.pt files found in ./models; the preselected
+ # value presumably has to be one of them - TODO confirm.
+ model_name = gr.Dropdown(
+ [ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt')],
+ value='scalelsd-vitbase-v1-train-sa1b.pt',
+ label="Model Selection"
+ )
+
+ with gr.Row():
+ save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files")
+
+ # Every widget below is passed explicitly to process_image, so the
+ # initial values here take the place of its keyword defaults.
+ with gr.Row():
+ with gr.Column():
+ threshold = gr.Number(10, label="Line Threshold")
+ junction_threshold_hm = gr.Number(0.008, label="Junction Threshold")
+ num_junctions_inference = gr.Number(1024, label="Max Number of Junctions")
+ width = gr.Number(512, label="Input Width")
+ height = gr.Number(512, label="Input Height")
+
+ with gr.Column():
+ draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only")
+ use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier")
+ use_nms = gr.Checkbox(True, label="Use NMS")
+ output_format = gr.Dropdown(
+ ['png', 'jpg', 'pdf'],
+ value='png',
+ label="Output Format"
+ )
+ whitebg = gr.Slider(0.0, 1.0, value=1.0, label="White Background Opacity")
+ line_width = gr.Number(2, label="Line Width")
+ juncs_size = gr.Number(8, label="Junctions Size")
+
+ with gr.Row():
+ edge_color = gr.Dropdown(
+ ['orange', 'midnightblue', 'red', 'green'],
+ value='midnightblue',
+ label="Edge Color"
+ )
+ vertex_color = gr.Dropdown(
+ ['Cyan', 'deeppink', 'yellow', 'purple'],
+ value='deeppink',
+ label="Vertex Color"
+ )
+
+ with gr.Row():
+ randomize_seed = gr.Checkbox(False, label="Randomize Seed")
+ seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
+
+ # Clicking an example fills both input image slots.
+ gr.Examples(
+ examples=example_pairs,
+ inputs=[input_image1, input_image2]
+ )
+
+ # run event handler: the order of `inputs` must match the positional
+ # parameters of process_image; its first return value is written back
+ # into the second input image slot
+ run_event = run_btn.click(
+ fn=process_image,
+ inputs=[
+ input_image1,
+ input_image2,
+ model_name,
+ save_name,
+ threshold,
+ junction_threshold_hm,
+ num_junctions_inference,
+ width,
+ height,
+ line_width,
+ juncs_size,
+ whitebg,
+ draw_junctions_only,
+ use_lsd,
+ use_nms,
+ edge_color,
+ vertex_color,
+ output_format,
+ seed,
+ randomize_seed
+ ],
+ outputs=[input_image2, mat_images, det_images, det_image1, det_image2, mat_file, det_file],
+ )
+
+ # stop event handler: cancels the pending run event and resets the buttons
+ stop_btn.click(
+ fn=stop_run,
+ outputs=[run_btn, stop_btn],
+ cancels=[run_event],
+ queue=False,
+ )
+
+ # When image1 changes, image2 is cleared
+ input_image1.change(
+ fn=clear_image2,
+ outputs=input_image2
+ )
+
+
+ return demo
+
+if __name__ == "__main__":
+    # Launch the application. The Blocks object is not assigned to `demo`,
+    # so the module-level factory function keeps its name.
+    demo().launch()
diff --git a/line_matching/run.py b/line_matching/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f7d50c2186cb1bf9ddd2bbc6ebe6f8888c0511d
--- /dev/null
+++ b/line_matching/run.py
@@ -0,0 +1,191 @@
+import argparse
+import os
+from os.path import join
+import sys
+import numpy as np
+import cv2
+import torch
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+
+from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
+from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
+from line_matching.two_view_pipeline import TwoViewPipeline
+
+from scalelsd.base import show, WireframeGraph
+from scalelsd.ssl.datasets.transforms.homographic_transforms import sample_homography
+from kornia.geometry import warp_perspective,transform_points
+
+class HADConfig:
+    """Class-level settings for random homography sampling in this module."""
+
+    # General settings
+    num_iter = 1
+    valid_border_margin = 3
+    allow_artifacts = False
+    patch_ratio = 0.85
+
+    # Which transformation components are enabled
+    translation = True
+    rotation = True
+    scale = True
+    perspective = True
+
+    # Amplitudes of the enabled components
+    scaling_amplitude = 0.2
+    perspective_amplitude_x = 0.2
+    perspective_amplitude_y = 0.2
+had_cfg = HADConfig()
+
+def sample_homographics(height, width):
+    """Sample one random homography for an image of the given size.
+
+    The sampling options come from the module-level ``had_cfg``. The
+    result is placed on the GPU when CUDA is available and on the CPU
+    otherwise, matching the device fallback used by ``main``.
+
+    Args:
+        height: Image height passed to ``sample_homography``.
+        width: Image width passed to ``sample_homography``.
+
+    Returns:
+        A dict with the sampled homography under ``'h.1'`` and its inverse
+        under ``'ih.1'``, both float tensors with a leading batch axis of
+        size 1.
+    """
+    homographic = sample_homography(
+        shape=(height, width),
+        perspective=had_cfg.perspective,
+        scaling=had_cfg.scale,
+        rotation=had_cfg.rotation,
+        translation=had_cfg.translation,
+        scaling_amplitude=had_cfg.scaling_amplitude,
+        perspective_amplitude_x=had_cfg.perspective_amplitude_x,
+        perspective_amplitude_y=had_cfg.perspective_amplitude_y,
+        patch_ratio=had_cfg.patch_ratio,
+        # Read the flag from the config (currently False) rather than
+        # repeating the literal here.
+        allow_artifacts=had_cfg.allow_artifacts,
+    )[0]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    homographic = torch.from_numpy(homographic[None]).float().to(device)
+    homographic_inv = torch.inverse(homographic)
+
+    return {
+        'h.1': homographic,
+        'ih.1': homographic_inv,
+    }
+
+def trans_image_with_homograpy(image):
+    """Warp ``image`` with a randomly sampled homography and preview the result.
+
+    Args:
+        image: NumPy image indexed as (row, column, channel).
+
+    Returns:
+        The warped image as a ``uint8`` NumPy array with the same height
+        and width as the input.
+    """
+    h, w = image.shape[:2]
+    homography = sample_homographics(height=h, width=w)['h.1']
+
+    # Follow the homography's device instead of assuming CUDA.
+    image_tensor = torch.Tensor(image).permute(2, 0, 1)[None].to(homography.device)
+    image_warped = warp_perspective(image_tensor, homography, (h, w))
+    image_warped_ = image_warped[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
+
+    # Preview of the synthetic second view for the command-line user.
+    plt.imshow(image_warped_)
+    plt.show()
+    return image_warped_
+
+
+def main():
+    """Match line segments between two images and save the visualizations.
+
+    When ``-img2`` is omitted, the second view is synthesized by warping
+    the first image with a random homography. All results are written to
+    ``temp_output/matching_results``.
+
+    Raises:
+        ValueError: If neither image path is available.
+        FileNotFoundError: If an image cannot be read.
+    """
+    # Parse input parameters
+    parser = argparse.ArgumentParser(
+        prog='GlueStick Demo',
+        description='Demo app to show the point and line matches obtained by GlueStick')
+    parser.add_argument('-img1', default='assets/figs/sa_1119229.jpg')
+    parser.add_argument('-img2', default=None)
+    parser.add_argument('--max_pts', type=int, default=1000)
+    parser.add_argument('--max_lines', type=int, default=300)
+    # NOTE(review): --model is parsed but never read below.
+    parser.add_argument('--model', type=str, default='models/paper-sa1b-997pkgs-model.pt')
+    args = parser.parse_args()
+
+    # important
+    if args.img1 is None and args.img2 is None:
+        raise ValueError("Input at least one path of image1 or image2")
+
+    # Evaluation config
+    conf = {
+        'name': 'two_view_pipeline',
+        'use_lines': True,
+        'extractor': {
+            'name': 'wireframe',
+            'sp_params': {
+                'force_num_keypoints': False,
+                'max_num_keypoints': args.max_pts,
+            },
+            'wireframe_params': {
+                'merge_points': True,
+                'merge_line_endpoints': True,
+                # 'merge_line_endpoints': False,
+            },
+            'max_n_lines': args.max_lines,
+        },
+        'matcher': {
+            'name': 'gluestick',
+            'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
+            'trainable': False,
+        },
+        'ground_truth': {
+            'from_pose_depth': False,
+        }
+    }
+
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    pipeline_model = TwoViewPipeline(conf).to(device).eval()
+    pipeline_model.extractor.update_conf(None)
+
+    saveto = 'temp_output/matching_results'
+    os.makedirs(saveto, exist_ok=True)
+
+    def read_image(path, flags=cv2.IMREAD_COLOR):
+        """Read ``path`` with OpenCV and fail loudly when it cannot be loaded."""
+        image = cv2.imread(path, flags)
+        if image is None:
+            # cv2.imread signals failure by returning None instead of raising.
+            raise FileNotFoundError(f'Cannot read image: {path}')
+        return image
+
+    if args.img2 is None:
+        # The warp runs on RGB data so that its preview shows true colors.
+        image1 = cv2.cvtColor(read_image(args.img1), cv2.COLOR_BGR2RGB)
+        image2 = trans_image_with_homograpy(image1)
+        args.img2 = f'{saveto}/warped_image.png'
+        # cv2.imwrite expects BGR; writing the RGB array directly would
+        # swap the red and blue channels on disk.
+        cv2.imwrite(args.img2, cv2.cvtColor(image2, cv2.COLOR_RGB2BGR))
+
+    # Read each view once; the arrays serve both inference and plotting.
+    gray0 = read_image(args.img1, cv2.IMREAD_GRAYSCALE)
+    gray1 = read_image(args.img2, cv2.IMREAD_GRAYSCALE)
+
+    torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1)
+    torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None]
+
+    x = {'image0': torch_gray0, 'image1': torch_gray1}
+    pred = pipeline_model(x)
+    pred = batch_to_np(pred)
+
+    line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
+    line_matches = pred["line_matches0"]
+
+    # -1 marks a line in view 0 without a partner in view 1.
+    valid_matches = line_matches != -1
+    match_indices = line_matches[valid_matches]
+    matched_lines0 = line_seg0[valid_matches]
+    matched_lines1 = line_seg1[match_indices]
+
+    # Plot the matches
+    img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)
+
+    plot_images([img0, img1], dpi=200, pad=2.0)
+    plot_lines([line_seg0, line_seg1], ps=4, lw=2)
+    plt.gcf().canvas.manager.set_window_title('Detected Lines')
+    # plt.tight_layout()
+    plt.savefig(f'{saveto}/det.png')
+    plt.close()  # the figure is not needed once it has been saved
+
+    plot_images([img0, img1], dpi=200, pad=2.0)
+    plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
+    plt.gcf().canvas.manager.set_window_title('Line Matches')
+    # plt.tight_layout()
+    plt.savefig(f'{saveto}/mat.png')
+    plt.close()
+
+    whitebg = 1
+    show.Canvas.white_overlay = whitebg
+    painter = show.painters.HAWPainter()
+
+    # Per-view wireframe renderings on top of the source images.
+    views = (
+        (args.img1, line_seg0, f'{saveto}/det1.png'),
+        (args.img2, line_seg1, f'{saveto}/det2.png'),
+    )
+    for image_path, lines, fig_file in views:
+        outputs = {'lines_pred': lines.reshape(-1, 4)}
+        with show.image_canvas(image_path, fig_file=fig_file) as ax:
+            painter.draw_wireframe(ax, outputs, edge_color='midnightblue', vertex_color='deeppink')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/line_matching/run_list.py b/line_matching/run_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..170b8b7acab892504b58e60356e5dbdf7d984e13
--- /dev/null
+++ b/line_matching/run_list.py
@@ -0,0 +1,144 @@
+import argparse
+import os
+from os.path import join
+import sys
+
+import cv2
+import torch
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+
+from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
+from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
+# from gluestick.models.two_view_pipeline import TwoViewPipeline
+from line_matching.two_view_pipeline import TwoViewPipeline
+
+from scalelsd.base import show, WireframeGraph
+
+def main():
+    """Match line segments over a list of reference/target image pairs.
+
+    Pairs are read from ``--test_root`` as ``ref_<id>.png`` / ``tgt_<id>.png``.
+    The ids to process are selected as follows:
+
+    * ``-inum`` and ``-imax`` given: the inclusive range ``inum..imax``;
+    * only ``-inum`` given: that single pair;
+    * only ``-imax`` given: the inclusive range ``0..imax``;
+    * neither given: ``0..K-1`` where ``K`` is half the number of entries
+      found in ``--test_root``.
+
+    For every pair, the detection and match figures are written to
+    ``temp_output/matching_results/<model>/<id>/``.
+
+    Raises:
+        FileNotFoundError: if one image of a pair cannot be read.
+    """
+    # Parse input parameters
+    parser = argparse.ArgumentParser(
+        prog='GlueStick Demo',
+        description='Demo app to show the point and line matches obtained by GlueStick')
+    parser.add_argument('-inum', default=None, type=int)
+    parser.add_argument('-imax', default=None, type=int)
+    # NOTE: -img1/-img2 are kept for command-line compatibility; the pair
+    # paths are always derived from --test_root and the pair id below.
+    parser.add_argument('-img1', default=join('resources' + os.path.sep + 'img1.jpg'))
+    parser.add_argument('-img2', default=join('resources' + os.path.sep + 'img2.jpg'))
+    parser.add_argument('--max_pts', type=int, default=1000)
+    parser.add_argument('--max_lines', type=int, default=300)
+    parser.add_argument('--model', default='scalelsd', type=str)
+    parser.add_argument('--test_root', type=str, default='data-ssl/0images-pre/')
+    args = parser.parse_args()
+
+    # Evaluation config
+    conf = {
+        'name': 'two_view_pipeline',
+        'use_lines': True,
+        'extractor': {
+            'name': 'wireframe',
+            'sp_params': {
+                'force_num_keypoints': False,
+                'max_num_keypoints': args.max_pts,
+            },
+            'wireframe_params': {
+                'merge_points': True,
+                'merge_line_endpoints': True,
+            },
+            'max_n_lines': args.max_lines,
+        },
+        'matcher': {
+            'name': 'gluestick',
+            'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
+            'trainable': False,
+        },
+        'ground_truth': {
+            'from_pose_depth': False,
+        }
+    }
+
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    pipeline_model = TwoViewPipeline(conf).to(device).eval()
+
+    # Passing None selects the extractor's built-in ScaleLSD settings.
+    pipeline_model.extractor.update_conf(None)
+
+    md = args.model
+    root = args.test_root
+
+    # Select the pair ids to process (see the docstring for the rules).
+    if args.inum is not None and args.imax is not None:
+        ids = range(args.inum, args.imax + 1)
+    elif args.inum is not None:
+        ids = [args.inum]
+    elif args.imax is not None:
+        # Only the upper bound was given: start from pair 0.
+        ids = range(args.imax + 1)
+    else:
+        # Each pair contributes two files (ref_*/tgt_*) to the directory.
+        ids = range(len(os.listdir(root)) // 2)
+
+    # The wireframe painter setup does not depend on the pair.
+    show.Canvas.white_overlay = 1
+    painter = show.painters.HAWPainter()
+
+    for idx in tqdm(ids):
+        saveto = f'temp_output/matching_results/{md}/{idx}'
+        os.makedirs(saveto, exist_ok=True)
+
+        img1_path = join(root, f'ref_{idx}.png')
+        img2_path = join(root, f'tgt_{idx}.png')
+
+        gray0 = cv2.imread(img1_path, 0)
+        gray1 = cv2.imread(img2_path, 0)
+        # cv2.imread signals failure by returning None instead of raising.
+        if gray0 is None or gray1 is None:
+            raise FileNotFoundError(
+                f'Could not read image pair {idx}: {img1_path}, {img2_path}')
+
+        torch_gray0 = numpy_image_to_torch(gray0).to(device)[None]
+        torch_gray1 = numpy_image_to_torch(gray1).to(device)[None]
+
+        # Pure inference: no autograd graph is needed.
+        with torch.no_grad():
+            pred = pipeline_model({'image0': torch_gray0, 'image1': torch_gray1})
+        pred = batch_to_np(pred)
+
+        line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
+        line_matches = pred["line_matches0"]
+
+        # line_matches[i] is the index in image 1 matched to line i of
+        # image 0, or -1 when line i is unmatched.
+        valid_matches = line_matches != -1
+        matched_lines0 = line_seg0[valid_matches]
+        matched_lines1 = line_seg1[line_matches[valid_matches]]
+
+        # Plot the matches
+        img0 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR)
+        img1 = cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)
+
+        plot_images([img0, img1], dpi=200, pad=2.0)
+        plot_lines([line_seg0, line_seg1], ps=4, lw=2)
+        plt.gcf().canvas.manager.set_window_title('Detected Lines')
+        plt.savefig(f'{saveto}/{md}_det_{idx}.png')
+        # Release the figure; otherwise one figure per pair accumulates.
+        plt.close()
+
+        plot_images([img0, img1], dpi=200, pad=2.0)
+        plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
+        plt.gcf().canvas.manager.set_window_title('Line Matches')
+        plt.savefig(f'{saveto}/{md}_mat_{idx}.png')
+        plt.close()
+
+        # Per-image wireframe renderings of the detected segments.
+        for img_path, segs, tag in ((img1_path, line_seg0, 'det1'),
+                                    (img2_path, line_seg1, 'det2')):
+            fig_file = f'{saveto}/{md}_{tag}.png'
+            outputs = {'lines_pred': segs.reshape(-1, 4)}
+            with show.image_canvas(img_path, fig_file=fig_file) as ax:
+                painter.draw_wireframe(ax, outputs, edge_color='midnightblue', vertex_color='deeppink')
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/line_matching/two_view_pipeline.py b/line_matching/two_view_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..c08d5858732d3b56cf86a96d1ff1943b794edaa5
--- /dev/null
+++ b/line_matching/two_view_pipeline.py
@@ -0,0 +1,167 @@
+"""
+A two-view sparse feature matching pipeline.
+
+This model contains sub-models for each step:
+ feature extraction, feature matching, outlier filtering, pose estimation.
+Each step is optional, and the features or matches can be provided as input.
+Default: SuperPoint with nearest neighbor matching.
+
+Convention for the matches: m0[i] is the index of the keypoint in image 1
+that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
+"""
+
+import numpy as np
+import torch
+
+from gluestick import get_model
+from gluestick.models.base_model import BaseModel
+from line_matching.wireframe import SPWireframeDescriptor
+
+
+def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
+    """Keep only the keypoints that fall in one randomly chosen image quadrant.
+
+    The quadrant's top-left corner is drawn with ``np.random.choice``
+    (x offset first, then y offset). The boolean mask built from
+    ``keypoints`` is applied to ``scores`` and to ``descs`` (after moving
+    its last axis in front of the channel axis), and a leading axis of
+    size 1 is re-added to every output.
+
+    Returns:
+        The filtered ``(keypoints, scores, descs)`` triple.
+    """
+    half_h = h // 2
+    half_w = w // 2
+    # Quadrant origin: x is sampled before y.
+    x_min = np.random.choice([0, half_w])
+    y_min = np.random.choice([0, half_h])
+
+    xs = keypoints[..., 0]
+    ys = keypoints[..., 1]
+    inside_x = (xs >= x_min) & (xs < x_min + half_w)
+    inside_y = (ys >= y_min) & (ys < y_min + half_h)
+    keep = inside_x & inside_y
+
+    kept_keypoints = keypoints[keep][None]
+    kept_scores = scores[keep][None]
+    kept_descs = descs.permute(0, 2, 1)[keep].t()[None]
+    return kept_keypoints, kept_scores, kept_descs
+
+
+def keep_random_kp_subset(keypoints, scores, descs, num_selected):
+    """Select a random subset of at most ``num_selected`` keypoints.
+
+    A random permutation of the indices along axis 1 of ``keypoints`` is
+    drawn with ``torch.randperm``; its first ``num_selected`` entries are
+    used to index axis 1 of ``keypoints`` and ``scores`` and axis 2 of
+    ``descs``.
+
+    Returns:
+        The sub-sampled ``(keypoints, scores, descs)`` triple.
+    """
+    total = keypoints.shape[1]
+    chosen = torch.randperm(total)[:num_selected]
+    return keypoints[:, chosen], scores[:, chosen], descs[:, :, chosen]
+
+
+def keep_best_kp_subset(keypoints, scores, descs, num_selected):
+    """Keep the ``num_selected`` keypoints with the highest scores.
+
+    Scores are sorted in ascending order along axis 1 and the tail of the
+    ordering is kept, so the outputs are ordered from lowest to highest
+    kept score. If ``num_selected`` is at least the number of keypoints,
+    all of them are kept; if it is 0, the outputs are empty along the
+    keypoint axis.
+
+    Args:
+        keypoints: tensor indexed as (batch, keypoint, 2).
+        scores: tensor indexed as (batch, keypoint).
+        descs: tensor indexed as (batch, channel, keypoint).
+        num_selected: number of keypoints to keep per batch element.
+
+    Returns:
+        The selected ``(keypoints, scores, descs)`` triple.
+    """
+    num_kp = scores.shape[1]
+    # An explicit start index is used because a ``[-num_selected:]`` slice
+    # would select everything when num_selected == 0.
+    first_kept = max(num_kp - num_selected, 0)
+    sorted_indices = torch.sort(scores, dim=1)[1]
+    selected_kp = sorted_indices[:, first_kept:]
+    keypoints = torch.gather(keypoints, 1,
+                             selected_kp[:, :, None].repeat(1, 1, 2))
+    scores = torch.gather(scores, 1, selected_kp)
+    descs = torch.gather(descs, 2,
+                         selected_kp[:, None].repeat(1, descs.shape[1], 1))
+    return keypoints, scores, descs
+
+
+class TwoViewPipeline(BaseModel):
+    """Two-view pipeline: per-image feature extraction followed by matching.
+
+    Each image is processed independently (keys ending in ``0`` / ``1``),
+    the per-image predictions are merged back with the same suffixes, and
+    the optional matcher, filter and solver are applied in that order.
+
+    NOTE: whenever ``conf.extractor.name`` is set, the extractor is always
+    an ``SPWireframeDescriptor``, regardless of the configured name.
+    """
+    default_conf = {
+        'extractor': {
+            'name': 'superpoint',
+            'trainable': False,
+        },
+        'use_lines': False,
+        'use_points': True,
+        'randomize_num_kp': False,
+        'detector': {'name': None},
+        'descriptor': {'name': None},
+        'matcher': {'name': 'nearest_neighbor_matcher'},
+        'filter': {'name': None},
+        'solver': {'name': None},
+        'ground_truth': {
+            'from_pose_depth': False,
+            'from_homography': False,
+            'th_positive': 3,
+            'th_negative': 5,
+            'reward_positive': 1,
+            'reward_negative': -0.25,
+            'is_likelihood_soft': True,
+            'p_random_occluders': 0,
+            'n_line_sampled_pts': 50,
+            'line_perp_dist_th': 5,
+            'overlap_th': 0.2,
+            'min_visibility_th': 0.5
+        },
+    }
+    required_data_keys = ['image0', 'image1']
+    strict_conf = False  # need to pass new confs to children models
+    components = [
+        'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver']
+
+    def _init(self, conf):
+        """Instantiate the sub-models enabled in ``conf``."""
+        if conf.extractor.name:
+            self.extractor = SPWireframeDescriptor(conf.extractor)
+        else:
+            # _forward falls back to a separate detector / descriptor when
+            # no extractor is configured, so they have to exist here.
+            if conf.detector.name:
+                self.detector = get_model(conf.detector.name)(conf.detector)
+            if conf.descriptor.name:
+                self.descriptor = get_model(
+                    conf.descriptor.name)(conf.descriptor)
+
+        if conf.matcher.name:
+            self.matcher = get_model(conf.matcher.name)(conf.matcher)
+        else:
+            # Without a matcher the matches must be supplied as input.
+            self.required_data_keys += ['matches0']
+
+        if conf.filter.name:
+            self.filter = get_model(conf.filter.name)(conf.filter)
+
+        if conf.solver.name:
+            self.solver = get_model(conf.solver.name)(conf.solver)
+
+    def _forward(self, data):
+        """Run extraction on both views, then matcher, filter and solver.
+
+        Returns:
+            A dict with the per-view predictions (keys suffixed with ``0``
+            or ``1``) merged with the outputs of every enabled later stage.
+        """
+
+        def process_siamese(data, i):
+            # Select the entries of view `i` and strip the view suffix.
+            data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i}
+            if self.conf.extractor.name:
+                pred_i = self.extractor(data_i)
+            else:
+                pred_i = {}
+                if self.conf.detector.name:
+                    pred_i = self.detector(data_i)
+                else:
+                    # No detector: reuse features already present in data.
+                    for k in ['keypoints', 'keypoint_scores', 'descriptors',
+                              'lines', 'line_scores', 'line_descriptors',
+                              'valid_lines']:
+                        if k in data_i:
+                            pred_i[k] = data_i[k]
+                if self.conf.descriptor.name:
+                    pred_i = {
+                        **pred_i, **self.descriptor({**data_i, **pred_i})}
+            return pred_i
+
+        pred0 = process_siamese(data, '0')
+        pred1 = process_siamese(data, '1')
+
+        pred = {**{k + '0': v for k, v in pred0.items()},
+                **{k + '1': v for k, v in pred1.items()}}
+
+        if self.conf.matcher.name:
+            pred = {**pred, **self.matcher({**data, **pred})}
+
+        if self.conf.filter.name:
+            pred = {**pred, **self.filter({**data, **pred})}
+
+        if self.conf.solver.name:
+            pred = {**pred, **self.solver({**data, **pred})}
+
+        return pred
+
+    def loss(self, pred, data):
+        """Merge the losses of all enabled components and sum their totals.
+
+        Components that raise ``NotImplementedError`` or that were not
+        instantiated are skipped.
+        """
+        losses = {}
+        total = 0
+        for k in self.components:
+            if not self.conf[k].name:
+                continue
+            component = getattr(self, k, None)
+            if component is None:
+                continue
+            try:
+                losses_ = component.loss(pred, {**pred, **data})
+            except NotImplementedError:
+                continue
+            losses = {**losses, **losses_}
+            total = losses_['total'] + total
+        return {**losses, 'total': total}
+
+    def metrics(self, pred, data):
+        """Merge the metrics of all enabled components.
+
+        Components that raise ``NotImplementedError`` or that were not
+        instantiated are skipped.
+        """
+        metrics = {}
+        for k in self.components:
+            if not self.conf[k].name:
+                continue
+            component = getattr(self, k, None)
+            if component is None:
+                continue
+            try:
+                metrics_ = component.metrics(pred, {**pred, **data})
+            except NotImplementedError:
+                continue
+            metrics = {**metrics, **metrics_}
+        return metrics
diff --git a/line_matching/wireframe.py b/line_matching/wireframe.py
new file mode 100644
index 0000000000000000000000000000000000000000..66c7c04837050828973be65e0d0598d9d11f34b4
--- /dev/null
+++ b/line_matching/wireframe.py
@@ -0,0 +1,341 @@
+import numpy as np
+import torch
+from pytlsd import lsd
+from sklearn.cluster import DBSCAN
+import sys
+
+from gluestick.models.base_model import BaseModel
+from gluestick.models.superpoint import SuperPoint, sample_descriptors
+from gluestick.geometry import warp_lines_torch
+
+from pathlib import Path
+import copy, cv2
+import os, glob
+import scalelsd
+from scalelsd.ssl.models.detector import ScaleLSD
+from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model
+
+
+def lines_to_wireframe(lines, line_scores, all_descs, conf):
+    """Merge nearby line endpoints into junctions and build a wireframe.
+
+    For every batch element the line endpoints are clustered with DBSCAN
+    (``eps=conf.nms_radius``, ``min_samples=1``); each cluster becomes one
+    junction placed at the mean of its endpoints and scored with the mean
+    score of the lines those endpoints belong to.
+
+    Returns:
+        junctions: list of [num_junc, 2] tensors with the junction positions
+        junc_scores: list of [num_junc] tensors with the junction scores
+        junc_descs: list of junction descriptors sampled from ``all_descs``
+        connectivity: list of [num_junc, num_junc] bool tensors, True on
+            the diagonal and wherever two junctions are joined by a line
+        new_lines: stacked [b_size, num_lines, 2, 2] lines whose endpoints
+            are snapped to their junctions
+        lines_junc_idx: stacked [b_size, num_lines, 2] junction index of
+            each line endpoint
+        num_true_junctions: list with the number of junctions per image
+    """
+    b_size, _, _, _ = all_descs.shape
+    device = lines.device
+    endpoints = lines.reshape(b_size, -1, 2)
+
+    junctions, junc_scores, junc_descs, connectivity = [], [], [], []
+    merged_lines, endpoint_junc_ids, num_true_junctions = [], [], []
+    for img_idx in range(b_size):
+        img_endpoints = endpoints[img_idx]
+
+        # Group endpoints that lie within the NMS radius of each other.
+        clustering = DBSCAN(eps=conf.nms_radius, min_samples=1)
+        labels_np = clustering.fit(img_endpoints.cpu().numpy()).labels_
+        n_junc = int(np.unique(labels_np).size)
+        num_true_junctions.append(n_junc)
+        labels = torch.tensor(labels_np, dtype=torch.long, device=device)
+
+        # Junction position: mean of the endpoints in the cluster.
+        junc_xy = torch.zeros(n_junc, 2, dtype=torch.float, device=device)
+        junc_xy.scatter_reduce_(0, labels[:, None].repeat(1, 2),
+                                img_endpoints, reduce='mean',
+                                include_self=False)
+        junctions.append(junc_xy)
+
+        # Junction score: mean score of the lines ending in the cluster
+        # (every line contributes its score once per endpoint).
+        endpoint_scores = torch.repeat_interleave(line_scores[img_idx], 2)
+        junc_score = torch.zeros(n_junc, dtype=torch.float, device=device)
+        junc_score.scatter_reduce_(0, labels, endpoint_scores,
+                                   reduce='mean', include_self=False)
+        junc_scores.append(junc_score)
+
+        # Lines expressed through their (merged) junctions.
+        line_ends = labels.reshape(-1, 2)
+        merged_lines.append(junc_xy[labels].reshape(-1, 2, 2))
+        endpoint_junc_ids.append(line_ends)
+
+        # Symmetric adjacency: two junctions are linked when a line joins them.
+        adjacency = torch.eye(n_junc, dtype=torch.bool, device=device)
+        start_ids, end_ids = line_ends[:, 0], line_ends[:, 1]
+        adjacency[start_ids, end_ids] = True
+        adjacency[end_ids, start_ids] = True
+        connectivity.append(adjacency)
+
+        # Interpolate the dense descriptors at the junction positions.
+        junc_descs.append(sample_descriptors(
+            junc_xy[None], all_descs[img_idx:(img_idx + 1)], 8)[0])
+
+    return (junctions, junc_scores, junc_descs, connectivity,
+            torch.stack(merged_lines, dim=0),
+            torch.stack(endpoint_junc_ids, dim=0),
+            num_true_junctions)
+
+
+class SPWireframeDescriptor(BaseModel):
+    """SuperPoint keypoints plus detected line segments, fused into a wireframe.
+
+    Line segments come from ScaleLSD (default) or from LSD, depending on
+    the configuration installed with ``update_conf``; keypoints and
+    descriptors come from SuperPoint.
+    """
+    default_conf = {
+        'sp_params': {
+            'has_detector': True,
+            'has_descriptor': True,
+            'descriptor_dim': 256,
+            'trainable': False,
+
+            # Inference
+            'return_all': True,
+            'sparse_outputs': True,
+            'nms_radius': 4,
+            'detection_threshold': 0.005,
+            'max_num_keypoints': 1000,
+            'force_num_keypoints': True,
+            'remove_borders': 4,
+        },
+        'wireframe_params': {
+            'merge_points': True,
+            'merge_line_endpoints': True,
+            'nms_radius': 3,
+            'max_n_junctions': 500,
+        },
+        'max_n_lines': 250,
+        'min_length': 15,
+    }
+    required_data_keys = ['image']
+
+    # Line-detector settings used for every key that the configuration
+    # passed to update_conf does not provide (or when it is None / empty).
+    _DEFAULT_SCALELSD_CONF = {
+        'model_name': 'scalelsd-vitbase-v1-train-sa1b.pt',
+        'junction_threshold_hm': 0.008,
+        'num_junctions_inference': 4096,
+        'width': 512,
+        'height': 512,
+        'use_lsd': False,
+        'use_nms': False,
+        'threshold': 5,
+    }
+
+    def _init(self, conf):
+        self.conf = conf
+        self.sp = SuperPoint(conf.sp_params)
+        self.extr_conf = {}
+        # Loaded ScaleLSD models keyed by (checkpoint, device). They are
+        # kept in a plain dict so that they are not registered as
+        # sub-modules of this model.
+        self._scalelsd_models = {}
+
+    def detect_lsd_lines(self, x, max_n_lines=None):
+        """Detect line segments with LSD on every image of the batch ``x``.
+
+        Segments shorter than ``conf.min_length`` are dropped and the rest
+        are ranked by ``score * sqrt(length)``; at most ``max_n_lines``
+        (default ``conf.max_n_lines``) are kept per image.
+
+        Returns:
+            ``(lines, scores, valid_lines)`` stacked over the batch.
+        """
+        if max_n_lines is None:
+            max_n_lines = self.conf.max_n_lines
+        lines, scores, valid_lines = [], [], []
+        for b in range(len(x)):
+            # For each image on batch
+            img = (x[b].squeeze().cpu().numpy() * 255).astype(np.uint8)
+            if max_n_lines is None:
+                b_segs = lsd(img)
+            else:
+                # Increase the LSD scale until enough segments are found.
+                for s in [0.3, 0.4, 0.5, 0.7, 0.8, 1.0]:
+                    b_segs = lsd(img, scale=s)
+                    if len(b_segs) >= max_n_lines:
+                        break
+
+            segs_length = np.linalg.norm(b_segs[:, 2:4] - b_segs[:, 0:2], axis=1)
+            # Remove short lines
+            b_segs = b_segs[segs_length >= self.conf.min_length]
+            segs_length = segs_length[segs_length >= self.conf.min_length]
+            b_scores = b_segs[:, -1] * np.sqrt(segs_length)
+            # Take the most relevant segments (highest score first)
+            indices = np.argsort(-b_scores)
+            if max_n_lines is not None:
+                indices = indices[:max_n_lines]
+            lines.append(torch.from_numpy(b_segs[indices, :4].reshape(-1, 2, 2)))
+            scores.append(torch.from_numpy(b_scores[indices]))
+            valid_lines.append(torch.ones_like(scores[-1], dtype=torch.bool))
+
+        lines = torch.stack(lines).to(x)
+        scores = torch.stack(scores).to(x)
+        valid_lines = torch.stack(valid_lines).to(x.device)
+        return lines, scores, valid_lines
+
+    def update_conf(self, conf):
+        """Install the line-detector configuration.
+
+        ``conf`` is either None / empty (use the built-in ScaleLSD
+        settings) or a mapping with a ``'model_name'`` entry: ``'lsd'``
+        selects LSD, any other value is the checkpoint file name looked up
+        under ``models/``. Other recognised keys are those of
+        ``_DEFAULT_SCALELSD_CONF``.
+        """
+        self.extr_conf = conf
+
+    def _get_scalelsd_model(self, ckpt, device):
+        """Return the ScaleLSD model for ``ckpt`` on ``device``, loading it once."""
+        cache = self.__dict__.setdefault('_scalelsd_models', {})
+        key = (ckpt, str(device))
+        if key not in cache:
+            cache[key] = load_scalelsd_model(ckpt, device)
+        return cache[key]
+
+    def _detect_scalelsd_lines(self, image, extr_conf):
+        """Detect lines with ScaleLSD on the first image of the batch.
+
+        Returns:
+            ``(lines, line_scores)`` restricted to the detections scoring
+            at least ``extr_conf['threshold']``; ``line_scores`` gets a
+            leading axis of size 1.
+        """
+        device = image.device
+        model = self._get_scalelsd_model(
+            'models/' + extr_conf['model_name'], device)
+        model.junction_threshold_hm = extr_conf['junction_threshold_hm']
+        model.num_junctions_inference = extr_conf['num_junctions_inference']
+        width, height = extr_conf['width'], extr_conf['height']
+
+        image_size = image.shape[-2:]
+        image_np = image[0, 0].cpu().numpy()
+        image_cp = copy.deepcopy(image_np)
+        image_torch = torch.from_numpy(cv2.resize(image_cp, (width, height))).float()
+        model_input = image_torch[None, None].to(device)
+        # The original image size is passed along with the resized input.
+        meta = {
+            'width': image_size[1],
+            'height': image_size[0],
+            'filename': '',
+            'use_lsd': extr_conf['use_lsd'],
+            'use_nms': extr_conf['use_nms'],
+        }
+        # Detection is inference only; skip building the autograd graph.
+        with torch.no_grad():
+            outputs, _ = model(model_input, meta)
+        lines = outputs[0]['lines_pred']
+        line_scores = outputs[0]['lines_score']
+        keep = line_scores >= extr_conf['threshold']
+        return lines[keep], line_scores[keep][None]
+
+    def _forward(self, data):
+        """Detect lines and keypoints and assemble the wireframe outputs.
+
+        Returns:
+            A dict with the merged points (``keypoints``,
+            ``keypoint_scores``, ``descriptors``), their associativity
+            matrix, the number of junctions, the lines (``lines``,
+            ``orig_lines``, ``line_scores``) and, per line endpoint, the
+            index of its junction (``lines_junc_idx``).
+        """
+        b_size, _, h, w = data['image'].shape
+        device = data['image'].device
+
+        if not self.conf.sp_params.force_num_keypoints:
+            assert b_size == 1, "Only batch size of 1 accepted for non padded inputs"
+
+        # Line detection
+        if 'lines' not in data or 'line_scores' not in data:
+            # Missing settings (or a None / empty configuration) fall back
+            # to the built-in ScaleLSD defaults.
+            extr_conf = {**self._DEFAULT_SCALELSD_CONF,
+                         **(self.extr_conf or {})}
+            if extr_conf['model_name'] != 'lsd':
+                lines, line_scores = self._detect_scalelsd_lines(
+                    data['image'], extr_conf)
+            else:
+                if 'original_img' in data:
+                    # Detect more lines, because when projecting them to the image most of them will be discarded
+                    lines, line_scores, valid_lines = self.detect_lsd_lines(
+                        data['original_img'], self.conf.max_n_lines * 3)
+                    # Apply the same transformation that is applied in homography_adaptation
+                    lines, valid_lines2 = warp_lines_torch(lines, data['H'], False, data['image'].shape[-2:])
+                    valid_lines = valid_lines & valid_lines2
+                    lines[~valid_lines] = -1
+                    line_scores[~valid_lines] = 0
+                    # Re-sort the line segments to pick the ones that are inside the image and have bigger score
+                    sorted_scores, sorting_indices = torch.sort(line_scores, dim=-1, descending=True)
+                    line_scores = sorted_scores[:, :self.conf.max_n_lines]
+                    sorting_indices = sorting_indices[:, :self.conf.max_n_lines]
+                    lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1)
+                    valid_lines = torch.take_along_dim(valid_lines, sorting_indices, 1)
+                else:
+                    lines, line_scores, valid_lines = self.detect_lsd_lines(data['image'], max_n_lines=1000000)
+        else:
+            lines, line_scores, valid_lines = data['lines'], data['line_scores'], data['valid_lines']
+        if line_scores.shape[-1] != 0:
+            # Normalise the scores by the maximum score of each image.
+            line_scores /= (line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None])
+
+        # SuperPoint prediction
+        pred = self.sp(data)
+
+        # Remove keypoints that are too close to line endpoints
+        if self.conf.wireframe_params.merge_points:
+            kp = pred['keypoints']
+            line_endpts = lines.reshape(b_size, -1, 2)
+            dist_pt_lines = torch.norm(
+                kp[:, :, None] - line_endpts[:, None], dim=-1)
+            # For each keypoint, mark it as valid or to remove
+            pts_to_remove = torch.any(
+                dist_pt_lines < self.conf.sp_params.nms_radius, dim=2)
+            # Simply remove them (we assume batch_size = 1 here)
+            assert len(kp) == 1
+            pred['keypoints'] = pred['keypoints'][0][~pts_to_remove[0]][None]
+            pred['keypoint_scores'] = pred['keypoint_scores'][0][~pts_to_remove[0]][None]
+            pred['descriptors'] = pred['descriptors'][0].T[~pts_to_remove[0]].T[None]
+
+        # Connect the lines together to form a wireframe
+        orig_lines = lines.clone()
+        if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0:
+            # Merge first close-by endpoints to connect lines
+            (line_points, line_pts_scores, line_descs, line_association,
+             lines, lines_junc_idx, num_true_junctions) = lines_to_wireframe(
+                lines, line_scores, pred['all_descriptors'],
+                conf=self.conf.wireframe_params)
+
+            # Add the keypoints to the junctions and fill the rest with random keypoints
+            (all_points, all_scores, all_descs,
+             pl_associativity) = [], [], [], []
+            for bs in range(b_size):
+                all_points.append(torch.cat(
+                    [line_points[bs], pred['keypoints'][bs]], dim=0))
+                all_scores.append(torch.cat(
+                    [line_pts_scores[bs], pred['keypoint_scores'][bs]], dim=0))
+                all_descs.append(torch.cat(
+                    [line_descs[bs], pred['descriptors'][bs]], dim=1))
+
+                # Identity for plain keypoints, line connectivity for junctions.
+                associativity = torch.eye(len(all_points[-1]), dtype=torch.bool, device=device)
+                associativity[:num_true_junctions[bs], :num_true_junctions[bs]] = \
+                    line_association[bs][:num_true_junctions[bs], :num_true_junctions[bs]]
+                pl_associativity.append(associativity)
+
+            all_points = torch.stack(all_points, dim=0)
+            all_scores = torch.stack(all_scores, dim=0)
+            all_descs = torch.stack(all_descs, dim=0)
+            pl_associativity = torch.stack(pl_associativity, dim=0)
+        else:
+            # Lines are independent
+            if lines.ndim != 4:
+                # Everything below indexes lines as (batch, line, endpoint,
+                # xy); detectors returning a flatter layout are reshaped.
+                lines = lines.reshape(b_size, -1, 2, 2)
+            all_points = torch.cat([lines.reshape(b_size, -1, 2),
+                                    pred['keypoints']], dim=1)
+            n_pts = all_points.shape[1]
+            num_lines = lines.shape[1]
+            num_true_junctions = [num_lines * 2] * b_size
+            all_scores = torch.cat([
+                torch.repeat_interleave(line_scores, 2, dim=1),
+                pred['keypoint_scores']], dim=1)
+            pred['line_descriptors'] = self.endpoints_pooling(
+                lines, pred['all_descriptors'], (h, w))
+            all_descs = torch.cat([
+                pred['line_descriptors'].reshape(b_size, self.conf.sp_params.descriptor_dim, -1),
+                pred['descriptors']], dim=2)
+            pl_associativity = torch.eye(
+                n_pts, dtype=torch.bool,
+                device=device)[None].repeat(b_size, 1, 1)
+            lines_junc_idx = torch.arange(
+                num_lines * 2, device=device).reshape(1, -1, 2).repeat(b_size, 1, 1)
+
+        del pred['all_descriptors']  # Remove dense descriptors to save memory
+        torch.cuda.empty_cache()
+
+        return {'keypoints': all_points,
+                'keypoint_scores': all_scores,
+                'descriptors': all_descs,
+                'pl_associativity': pl_associativity,
+                'num_junctions': torch.tensor(num_true_junctions),
+                'lines': lines,
+                'orig_lines': orig_lines,
+                'lines_junc_idx': lines_junc_idx,
+                'line_scores': line_scores,
+                }
+
+    @staticmethod
+    def endpoints_pooling(segs, all_descriptors, img_shape):
+        """Read the dense descriptors at the endpoints of every segment.
+
+        Endpoints are scaled from image coordinates to the resolution of
+        ``all_descriptors``, rounded and clipped to the map bounds.
+
+        Args:
+            segs: [b_size, num_lines, 2, 2] segments in image coordinates.
+            all_descriptors: dense descriptor map whose last two axes are
+                (height, width).
+            img_shape: (height, width) of the image ``segs`` refers to.
+
+        Returns:
+            The endpoint descriptors of all batch elements, concatenated
+            along the first axis.
+        """
+        assert segs.ndim == 4 and segs.shape[-2:] == (2, 2)
+        filter_shape = all_descriptors.shape[-2:]
+        scale_x = filter_shape[1] / img_shape[1]
+        scale_y = filter_shape[0] / img_shape[0]
+
+        scaled_segs = torch.round(segs * torch.tensor([scale_x, scale_y]).to(segs)).long()
+        scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1)
+        scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1)
+        line_descriptors = [all_descriptors[None, b, ..., torch.squeeze(b_segs[..., 1]), torch.squeeze(b_segs[..., 0])]
+                            for b, b_segs in enumerate(scaled_segs)]
+        line_descriptors = torch.cat(line_descriptors)
+        return line_descriptors
+
+    def loss(self, pred, data):
+        raise NotImplementedError
+
+    def metrics(self, pred, data):
+        return {}
diff --git a/predictor/predict.py b/predictor/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..8245dd19d043f898a8b48e5a2eae3cd9839d16a2
--- /dev/null
+++ b/predictor/predict.py
@@ -0,0 +1,131 @@
+import torch
+import random
+import numpy as np
+import os
+import os.path as osp
+import glob
+from tqdm import tqdm
+
+from scalelsd.base import setup_logger, MetricLogger, show, WireframeGraph
+
+from scalelsd.ssl.datasets import dataset_util
+from scalelsd.ssl.models.detector import ScaleLSD
+from scalelsd.ssl.misc.train_utils import load_scalelsd_model
+
+from torch.utils.data import DataLoader
+import torch.utils.data.dataloader as torch_loader
+
+from pathlib import Path
+import argparse, yaml, logging, time, datetime, cv2, copy, sys, json
+from easydict import EasyDict
+import accelerate
+from accelerate import load_checkpoint_and_dispatch
+import matplotlib
+import matplotlib.pyplot as plt
+
+def parse_args():
+    """Build the command-line parser and return the parsed arguments.
+
+    ``ScaleLSD.cli`` is called on the parser before parsing and
+    ``ScaleLSD.configure`` is called with the parsed arguments afterwards.
+
+    Returns:
+        The ``argparse.Namespace`` holding the parsed options.
+    """
+    aparser = argparse.ArgumentParser()
+    aparser.add_argument('-c', '--ckpt', default='models/scalelsd-vitbase-v1-train-sa1b.pt', type=str, help='the path for loading checkpoints')
+    aparser.add_argument('-t', '--threshold', default=10, type=float)
+    aparser.add_argument('-i', '--img', required=True, type=str)
+    aparser.add_argument('--width', default=512, type=int)
+    aparser.add_argument('--height', default=512, type=int)
+    aparser.add_argument('--whitebg', default=0.0, type=float)
+    aparser.add_argument('--saveto', default=None, type=str)
+    # Only the formats that main() can actually write are offered; 'txt'
+    # used to be listed here but was always rejected after inference.
+    aparser.add_argument('-e', '--ext', default='pdf', type=str, choices=['pdf', 'png', 'json'])
+    aparser.add_argument('--device', default='cuda', type=str, choices=['cuda', 'cpu', 'mps'])
+    aparser.add_argument('--disable-show', default=False, action='store_true')
+    aparser.add_argument('--draw-junctions-only', default=False, action='store_true')
+    aparser.add_argument('--use_lsd', default=False, action='store_true')
+    aparser.add_argument('--use_nms', default=False, action='store_true')
+
+    # Let the detector add its own options before parsing.
+    ScaleLSD.cli(aparser)
+
+    args = aparser.parse_args()
+
+    ScaleLSD.configure(args)
+
+    return args
+
+
+def main():
+    """Run ScaleLSD on one image or on every image of a directory.
+
+    Results are written to ``--saveto`` (default ``temp_output/ScaleLSD``)
+    either as a rendered figure (``png`` / ``pdf``) or as a JSON wireframe,
+    one output file per input image. Images that cannot be decoded are
+    reported and skipped.
+
+    Raises:
+        ValueError: if ``--ext`` is not one of png, pdf, json, or if
+            ``--img`` is neither an image file nor a directory.
+    """
+    args = parse_args()
+
+    # Fail before loading the model rather than after the first inference.
+    if args.ext not in ('png', 'pdf', 'json'):
+        raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))
+
+    model = load_scalelsd_model(args.ckpt, device=args.device)
+
+    # Set up output directory and painter
+    if args.saveto is None:
+        print('No output directory specified, saving outputs to folder: temp_output/ScaleLSD')
+        args.saveto = 'temp_output/ScaleLSD'
+    os.makedirs(args.saveto, exist_ok=True)
+
+    show.painters.HAWPainter.confidence_threshold = args.threshold
+    show.Canvas.show = not args.disable_show
+    if args.whitebg > 0.0:
+        show.Canvas.white_overlay = args.whitebg
+    painter = show.painters.HAWPainter()
+    edge_color = 'orange'
+    vertex_color = 'Cyan'
+
+    # Prepare images (extension match is case-insensitive)
+    image_exts = ('.jpg', '.jpeg', '.png')
+    if os.path.isfile(args.img) and args.img.lower().endswith(image_exts):
+        all_images = [args.img]
+    elif os.path.isdir(args.img):
+        all_images = sorted(
+            os.path.join(args.img, file)
+            for file in os.listdir(args.img)
+            if file.lower().endswith(image_exts))
+    else:
+        raise ValueError('Input must be a file or a directory containing images.')
+
+    # Inference
+    for fname in tqdm(all_images):
+        pname = Path(fname)
+        image = cv2.imread(fname, 0)
+        # cv2.imread returns None instead of raising on unreadable files.
+        if image is None:
+            print('Skipping unreadable image: {}'.format(fname))
+            continue
+
+        # The network input is resized (default [512, 512]); the original
+        # size is forwarded through `meta`.
+        ori_shape = image.shape[:2]
+        image_ = cv2.resize(image, (args.width, args.height))
+        image_ = torch.from_numpy(image_).float() / 255.0
+        image_ = image_[None, None].to(args.device)
+
+        meta = {
+            'width': ori_shape[1],
+            'height': ori_shape[0],
+            'filename': '',
+            'use_lsd': args.use_lsd,
+            'use_nms': args.use_nms,
+        }
+
+        with torch.no_grad():
+            outputs, _ = model(image_, meta)
+        outputs = outputs[0]
+
+        if args.ext in ['png', 'pdf']:
+            fig_file = osp.join(args.saveto, pname.with_suffix('.' + args.ext).name)
+            with show.image_canvas(fname, fig_file=fig_file) as ax:
+                if args.draw_junctions_only:
+                    painter.draw_junctions(ax, outputs)
+                else:
+                    painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
+        else:
+            # args.ext == 'json' (validated above)
+            indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
+            wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices, outputs['lines_score'], outputs['width'], outputs['height'])
+            outpath = osp.join(args.saveto, pname.with_suffix('.json').name)
+            with open(outpath, 'w') as f:
+                json.dump(wireframe.jsonize(), f)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0b7a4c8f884b4c7e5e23d17cef5b9ec283a01acf
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,21 @@
+
+opencv-python
+cython
+matplotlib
+yacs
+scikit-image
+tqdm
+python-json-logger
+h5py
+shapely
+pycolmap
+seaborn
+kornia
+easydict
+pynvml
+timm
+einops==0.7.0
+numpy==1.26.4
+gradio
+pydantic==2.10.6
+pytlsd@git+https://github.com/iago-suarez/pytlsd.git@4180ab8
diff --git a/scalelsd/.gitignore b/scalelsd/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2fc25ad384f6e2f539a1f7bfc09e0cac7f674896
--- /dev/null
+++ b/scalelsd/.gitignore
@@ -0,0 +1,10 @@
+__pycache__/
+*/__pycache__/
+**/__pycache__/
+
+data-ssl
+exp
+exp-ssl
+temp_output
+third_party
+./models
\ No newline at end of file
diff --git a/scalelsd/__init__.py b/scalelsd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..71c007af3673c2b2a1cedcf56b0c4812bc3f2492
--- /dev/null
+++ b/scalelsd/__init__.py
@@ -0,0 +1,2 @@
+from . import base
+from . import ssl
\ No newline at end of file
diff --git a/scalelsd/base/__init__.py b/scalelsd/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ba695ab8c98a07280c99cf60ed7e4023a999195
--- /dev/null
+++ b/scalelsd/base/__init__.py
@@ -0,0 +1,13 @@
+from .csrc import _C
+from . import utils
+from .utils.logger import setup_logger
+from .utils.metric_logger import MetricLogger
+from .wireframe import WireframeGraph
+
+__all__ = [
+ "_C",
+ "utils",
+ "setup_logger",
+ "MetricLogger",
+ "WireframeGraph",
+]
\ No newline at end of file
diff --git a/scalelsd/base/csrc/__init__.py b/scalelsd/base/csrc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..45eee912d74e7cd345337d53aba7fc5b06f9cc97
--- /dev/null
+++ b/scalelsd/base/csrc/__init__.py
@@ -0,0 +1,19 @@
+from torch.utils.cpp_extension import load
+import glob
+import os.path as osp
+
+__this__ = osp.dirname(__file__)
+
+# Build (or fetch from the torch extension cache) the C++/CUDA extension.
+# When that is not possible, e.g. because no compiler or CUDA toolkit is
+# available, `_C` is set to None so that the package can still be imported.
+try:
+    _C = load(
+        name='_C',
+        sources=[
+            osp.join(__this__, 'binding.cpp'),
+            osp.join(__this__, 'linesegment.cu'),
+        ],
+    )
+except Exception as err:
+    import warnings
+
+    warnings.warn(
+        'Could not build the scalelsd `_C` extension ({}); '
+        'continuing without it.'.format(err))
+    _C = None
+
+__all__ = ["_C"]
+
+#_C = load(name='base._C', sources=['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'])
diff --git a/scalelsd/base/csrc/binding.cpp b/scalelsd/base/csrc/binding.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1b25ef2502d6b9be633be281bf14980c2a43064e
--- /dev/null
+++ b/scalelsd/base/csrc/binding.cpp
@@ -0,0 +1,5 @@
+#include "linesegment.h"
+
+// Python bindings of the extension; the module name is supplied through the
+// TORCH_EXTENSION_NAME macro.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ // Expose the C++ function `encodels` to Python under the same name.
+ m.def("encodels", &encodels, "Encoding line segments to maps");
+}
\ No newline at end of file
diff --git a/scalelsd/base/csrc/linesegment.cu b/scalelsd/base/csrc/linesegment.cu
new file mode 100644
index 0000000000000000000000000000000000000000..cd6022d02619e15e70c7750269dc62ad401791f8
--- /dev/null
+++ b/scalelsd/base/csrc/linesegment.cu
@@ -0,0 +1,139 @@
+#include
+#include
+
+// #include
+// #include
+#include
+#include
+
+#include
+#include
+
+int const CUDA_NUM_THREADS = 1024;
+
+// Number of blocks of CUDA_NUM_THREADS threads needed to cover N items
+// (integer ceiling division).
+inline int CUDA_GET_BLOCKS(const int N) {
+  const int rounded_up = N + CUDA_NUM_THREADS - 1;
+  return rounded_up / CUDA_NUM_THREADS;
+}
+
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+
+// For every pixel of the (height x width) output grid, find the closest line
+// segment and store its geometry relative to that pixel.
+//
+//   nthreads       loop bound: number of pixel indices to process
+//   lines          `num` segments stored as consecutive (x1, y1, x2, y2)
+//                  floats, in the coordinates of the input image
+//   input_height,
+//   input_width    image size the segment coordinates refer to
+//   num            number of segments in `lines`
+//   height, width  size of the output grid
+//   map            6 planes of height*width floats:
+//                    planes 0-1: offset from the pixel to the closest point
+//                                on the closest segment
+//                    planes 2-3: offset to that segment's nearer endpoint
+//                    planes 4-5: offset to its other endpoint
+//   label          not written by this kernel; kept so the signature and the
+//                  launch site stay unchanged
+//   tmap           clamped projection parameter t in [0, 1] of the closest
+//                  point, stored at the pixel's loop index
+//
+// When num == 0 nothing is written.
+__global__ void encode_kernel(const int nthreads, const float* lines,
+    const int input_height, const int input_width, const int num,
+    const int height, const int width, float* map,
+    bool* label, float* tmap)
+{
+    // Scale from input-image coordinates to the output grid. These depend on
+    // neither the pixel nor the segment, so they are computed once.
+    const float xs = (float)width  /(float)input_width;
+    const float ys = (float)height /(float)input_height;
+    const int plane = height*width;
+
+    CUDA_1D_KERNEL_LOOP(index, nthreads){
+        const int w = index % width;
+        const int h = (index / width) % height;
+        const int pixel = h*width + w;
+
+        const int x_index  = pixel;
+        const int y_index  = plane + pixel;
+        const int ux_index = 2*plane + pixel;
+        const int uy_index = 3*plane + pixel;
+        const int vx_index = 4*plane + pixel;
+        const int vy_index = 5*plane + pixel;
+
+        const float px = (float) w;
+        const float py = (float) h;
+        float min_dis = 1e30;
+
+        for (int i = 0; i < num; ++i) {
+            const float x1 = lines[4*i  ]*xs;
+            const float y1 = lines[4*i+1]*ys;
+            const float x2 = lines[4*i+2]*xs;
+            const float y2 = lines[4*i+3]*ys;
+
+            const float dx = x2 - x1;
+            const float dy = y2 - y1;
+            const float ux = x1 - px;
+            const float uy = y1 - py;
+            const float vx = x2 - px;
+            const float vy = y2 - py;
+            const float norm2 = dx*dx + dy*dy;
+
+            // Projection of the pixel onto the supporting line, clamped to
+            // the segment. The epsilon keeps zero-length segments finite.
+            float t = ((px-x1)*dx + (py-y1)*dy)/(norm2+1e-6);
+            t = t<0.0? 0.0:t;
+            t = t>1.0? 1.0:t;
+
+            const float ax = x1 + t*(x2-x1) - px;
+            const float ay = y1 + t*(y2-y1) - py;
+
+            const float dis = ax*ax + ay*ay;
+            if (dis < min_dis) {
+                min_dis = dis;
+                map[x_index] = ax;
+                map[y_index] = ay;
+
+                // Store the nearer endpoint first; ties go to (x2, y2).
+                const float norm_u2 = ux*ux+uy*uy;
+                const float norm_v2 = vx*vx+vy*vy;
+                if (norm_u2 < norm_v2){
+                    map[ux_index] = ux;
+                    map[uy_index] = uy;
+                    map[vx_index] = vx;
+                    map[vy_index] = vy;
+                }
+                else{
+                    map[ux_index] = vx;
+                    map[uy_index] = vy;
+                    map[vx_index] = ux;
+                    map[vy_index] = uy;
+                }
+
+                tmap[index] = t;
+            }
+        }
+    }
+}
+
+
+// Host-side launcher for encode_kernel.
+//
+//   lines          CUDA float32 tensor holding at least 4 * num_lines values,
+//                  read as consecutive (x1, y1, x2, y2) segments
+//   input_height,
+//   input_width    image size the segment coordinates refer to
+//   height, width  size of the output grid
+//   num_lines      number of segments to encode
+//
+// Returns (map, label, tmap): a {6, height, width} float tensor, a
+// {1, height, width} bool tensor (allocated as zeros; the kernel does not
+// write it) and a {1, height, width} float tensor, all created with the
+// options of `lines`.
+std::tuple<at::Tensor, at::Tensor, at::Tensor> lsencode_cuda(
+    const at::Tensor& lines,
+    const int input_height,
+    const int input_width,
+    const int height,
+    const int width,
+    const int num_lines)
+
+{
+    // The kernel dereferences raw device float pointers, so reject anything else.
+    TORCH_CHECK(lines.is_cuda(), "lsencode_cuda: `lines` must be a CUDA tensor");
+    TORCH_CHECK(lines.scalar_type() == at::kFloat,
+                "lsencode_cuda: `lines` must be a float32 tensor");
+    TORCH_CHECK(num_lines >= 0 &&
+                lines.numel() >= 4 * static_cast<int64_t>(num_lines),
+                "lsencode_cuda: `lines` holds fewer than 4 * num_lines values");
+
+    auto map = at::zeros({6,height,width}, lines.options());
+    auto tmap = at::zeros({1,height,width}, lines.options());
+    auto label = at::zeros({1,height,width}, lines.options().dtype(at::kBool));
+    const int nthreads = height*width;
+
+    // Nothing to compute: the outputs stay zero. This also avoids launching
+    // a grid with zero blocks.
+    if (nthreads == 0 || num_lines == 0) {
+        return std::make_tuple(map, label, tmap);
+    }
+
+    // Named so the contiguous copy (if one is made) outlives the launch call.
+    const at::Tensor lines_c = lines.contiguous();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    float* map_data = map.data_ptr<float>();
+    float* tmap_data = tmap.data_ptr<float>();
+    bool* label_data = label.data_ptr<bool>();
+
+    encode_kernel<<<CUDA_GET_BLOCKS(nthreads), CUDA_NUM_THREADS, 0, stream>>>(
+        nthreads,
+        lines_c.data_ptr<float>(),
+        input_height, input_width,
+        num_lines,
+        height, width,
+        map_data,
+        label_data,
+        tmap_data);
+
+    // Surface launch-configuration errors instead of ignoring them.
+    AT_CUDA_CHECK(cudaGetLastError());
+
+    return std::make_tuple(map, label, tmap);
+}
\ No newline at end of file
diff --git a/scalelsd/base/csrc/linesegment.h b/scalelsd/base/csrc/linesegment.h
new file mode 100644
index 0000000000000000000000000000000000000000..b3e93237c96cd97ce1ebc564656d11d36dbf86fe
--- /dev/null
+++ b/scalelsd/base/csrc/linesegment.h
@@ -0,0 +1,26 @@
+// #pragma once
+#include
+
+// CUDA implementation of the encoder (defined in linesegment.cu).
+std::tuple<at::Tensor, at::Tensor, at::Tensor> lsencode_cuda(
+    const at::Tensor& lines,
+    const int input_height,
+    const int input_width,
+    const int height,
+    const int width,
+    const int num_lines);
+
+// Entry point bound to Python; forwards all arguments to lsencode_cuda.
+// Declared `inline` because the definition lives in a header: without it,
+// including this header from more than one translation unit would produce
+// duplicate-symbol link errors.
+inline std::tuple<at::Tensor, at::Tensor, at::Tensor> encodels(
+    const at::Tensor& lines,
+    const int input_height,
+    const int input_width,
+    const int height,
+    const int width,
+    const int num_lines)
+{
+    return lsencode_cuda(lines,
+                         input_height,
+                         input_width,
+                         height,
+                         width,
+                         num_lines);
+}
\ No newline at end of file
diff --git a/scalelsd/base/show/__init__.py b/scalelsd/base/show/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae4a667b196cfb5d90a9b83a2bbd718cabc61403
--- /dev/null
+++ b/scalelsd/base/show/__init__.py
@@ -0,0 +1,3 @@
+from .canvas import Canvas, image_canvas, canvas
+from .painters import HAWPainter
+from .cli import cli, configure
\ No newline at end of file
diff --git a/scalelsd/base/show/canvas.py b/scalelsd/base/show/canvas.py
new file mode 100644
index 0000000000000000000000000000000000000000..11dea0b643a52ac8f2b66af480ebccc423109b09
--- /dev/null
+++ b/scalelsd/base/show/canvas.py
@@ -0,0 +1,153 @@
+from contextlib import contextmanager
+import logging
+import os
+
+from matplotlib.pyplot import figimage, margins
+import numpy as np
+import cv2
+
+try:
+ import matplotlib.pyplot as plt # pylint: disable=import-error
+
+except ModuleNotFoundError as err:
+ if err.name != 'matplotlib':
+ raise err
+ plt = None
+
+
+LOG = logging.getLogger(__name__)
+
+class Canvas:
+ """Canvas for plotting.
+ All methods expose Axes objects. To get Figure objects, you can ask the axis
+ `ax.get_figure()`.
+ """
+
+ all_images_directory = None
+ all_images_count = 0
+ show = False
+ image_width = 7.0
+ image_height = None
+ blank_dpi = 200
+ image_dpi_factor = 1.0
+ image_min_dpi = 50.0
+ out_file_extension = 'pdf'
+ white_overlay = False
+
+ @classmethod
+ def generic_name(cls):
+ if cls.all_images_directory is None:
+ return None
+ os.makedirs(cls.all_images_directory, exist_ok=True)
+
+ cls.all_images_count += 1
+ return os.path.join(cls.all_images_directory,
+ '{:04}.{}'.format(cls.all_images_count, cls.out_file_extension))
+
+ @classmethod
+ @contextmanager
+ def blank(cls, fig_file=None, *, dpi=None, nomargin=False, **kwargs):
+ if plt is None:
+ raise Exception('please install matplotlib')
+ if fig_file is None:
+ fig_file = cls.generic_name()
+
+ if dpi is None:
+ dpi = cls.blank_dpi
+
+ if 'figsize' not in kwargs:
+ kwargs['figsize'] = (10, 6)
+
+ if nomargin:
+ if 'gridspec_kw' not in kwargs:
+ kwargs['gridspec_kw'] = {}
+ kwargs['gridspec_kw']['wspace'] = 0
+ kwargs['gridspec_kw']['hspace'] = 0
+ kwargs['gridspec_kw']['left'] = 0.0
+ kwargs['gridspec_kw']['right'] = 1.0
+ kwargs['gridspec_kw']['top'] = 1.0
+ kwargs['gridspec_kw']['bottom'] = 0.0
+
+ fig, ax = plt.subplots(dpi=dpi, **kwargs)
+
+ yield ax
+
+ fig.set_tight_layout(not margins)
+ if fig_file:
+ LOG.debug('writing image to %s', fig_file)
+ fig.savefig(fig_file)
+
+ if cls.show:
+ plt.show()
+ plt.close(fig)
+
+
+ @classmethod
+ @contextmanager
+ def image(cls, image, fig_file=None, *, margin=None, **kwargs):
+ if plt is None:
+ raise Exception('please install matplotlib')
+ if fig_file is None:
+ fig_file = cls.generic_name()
+
+ if isinstance(image, str):
+ image = cv2.imread(image)[...,::-1]
+ else:
+ image = np.asarray(image)
+
+ if margin is None:
+ margin = [0.0, 0.0, 0.0, 0.0]
+ elif isinstance(margin, float):
+ margin = [margin, margin, margin, margin]
+ assert len(margin) == 4
+
+ if 'figsize' not in kwargs:
+ # compute figure size: use image ratio and take the drawable area
+ # into account that is left after subtracting margins.
+ image_ratio = image.shape[0] / image.shape[1]
+ image_area_ratio = (1.0 - margin[1] - margin[3]) / (1.0 - margin[0] - margin[2])
+ if cls.image_width is not None:
+ kwargs['figsize'] = (
+ cls.image_width,
+ cls.image_width * image_ratio / image_area_ratio
+ )
+ elif cls.image_height:
+ kwargs['figsize'] = (
+ cls.image_height * image_area_ratio / image_ratio,
+ cls.image_height
+ )
+
+ # dpi = max(cls.image_min_dpi, image.shape[1] / kwargs['figsize'][0] * cls.image_dpi_factor)
+ dpi = 200
+ # import pdb; pdb.set_trace()
+ fig = plt.figure(dpi=dpi, **kwargs)
+ ax = plt.Axes(fig, [0.0 + margin[0],
+ 0.0 + margin[1],
+ 1.0 - margin[2],
+ 1.0 - margin[3]])
+
+ ax.set_axis_off()
+ ax.set_xlim(-0.5, image.shape[1] - 0.5) # imshow uses center-pixel-coordinates
+ ax.set_ylim(image.shape[0] - 0.5, -0.5)
+ fig.add_axes(ax)
+ ax.imshow(image)
+ if cls.white_overlay:
+ white_screen(ax, cls.white_overlay)
+ yield ax
+
+ if fig_file:
+ LOG.debug('writing image to %s', fig_file)
+ fig.savefig(fig_file)
+ if cls.show:
+ plt.show()
+ import pdb;pdb.set_trace()
+ plt.close(fig)
+
+def white_screen(ax, alpha=0.9):
+ ax.add_patch(
+ plt.Rectangle((0, 0), 1, 1, transform=ax.transAxes, alpha=alpha,
+ facecolor='white')
+ )
+
+canvas = Canvas.blank
+image_canvas = Canvas.image
\ No newline at end of file
diff --git a/scalelsd/base/show/cli.py b/scalelsd/base/show/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..171407457dfde710c6f15bb3d188bec5c1213021
--- /dev/null
+++ b/scalelsd/base/show/cli.py
@@ -0,0 +1,24 @@
+# from hawp.config import defaults
+import logging
+
+from .canvas import Canvas
+from .painters import HAWPainter
+import matplotlib
+LOG = logging.getLogger(__name__)
+
+def cli(parser):
+ group = parser.add_argument_group('show')
+
+ assert not Canvas.show
+ group.add_argument('--show', default=False,action='store_true',
+ help='show every plot, i.e., call matplotlib show()')
+
+ group.add_argument('--edge-threshold', default=None, type=float,
+ help='show the wireframe edges whose confidences are greater than [edge_threshold]')
+ group.add_argument('--out-ext', default='png', type=str,
+ help='save the plot in specific format')
+def configure(args):
+ Canvas.show = args.show
+ Canvas.out_file_extension = args.out_ext
+ if args.edge_threshold is not None:
+ HAWPainter.confidence_threshold = args.edge_threshold
\ No newline at end of file
diff --git a/scalelsd/base/show/painters.py b/scalelsd/base/show/painters.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb5ca7bd89cf41b40ba695cd4da4e03558e24499
--- /dev/null
+++ b/scalelsd/base/show/painters.py
@@ -0,0 +1,80 @@
+import logging
+
+import numpy as np
+import torch
+
+
+try:
+ import matplotlib
+ import matplotlib.animation
+ import matplotlib.collections
+ import matplotlib.patches
+except ImportError:
+ matplotlib = None
+
+
+LOG = logging.getLogger(__name__)
+
+
+class HAWPainter:
+ # line_width = None
+ # marker_size = None
+ line_width = 2
+ marker_size = 4
+
+ confidence_threshold = 0.05
+
+ def __init__(self):
+
+ if self.line_width is None:
+ self.line_width = 1
+
+ if self.marker_size is None:
+ self.marker_size = max(1, int(self.line_width * 0.5))
+
+ def draw_junctions(self, ax, wireframe, *,
+ edge_color = None, vertex_color = None):
+ if wireframe is None:
+ return
+
+ if edge_color is None:
+ edge_color = 'b'
+ if vertex_color is None:
+ vertex_color = 'c'
+
+ if 'lines_score' in wireframe.keys():
+ line_segments = wireframe['lines_pred'][wireframe['lines_score']>self.confidence_threshold]
+ else:
+ line_segments = wireframe['lines_pred']
+
+ if isinstance(line_segments, torch.Tensor):
+ line_segments = line_segments.cpu().numpy()
+
+ ax.plot(line_segments[:,0],line_segments[:,1],'.',color=vertex_color)
+ ax.plot(line_segments[:,2],line_segments[:,3],'.',
+ color=vertex_color)
+ def draw_wireframe(self, ax, wireframe, *,
+ edge_color = None, vertex_color = None):
+ if wireframe is None:
+ return
+
+ if edge_color is None:
+ edge_color = 'b'
+ if vertex_color is None:
+ vertex_color = 'c'
+
+ if 'lines_score' in wireframe.keys():
+ line_segments = wireframe['lines_pred'][wireframe['lines_score']>self.confidence_threshold]
+ else:
+ line_segments = wireframe['lines_pred']
+
+ # import pdb;pdb.set_trace()
+ if isinstance(line_segments, torch.Tensor):
+ line_segments = line_segments.cpu().numpy()
+
+ # import pdb;pdb.set_trace()
+ # line_segments = wireframe.line_segments(threshold=self.confidence_threshold)
+ # line_segments = line_segments.cpu().numpy()
+ ax.plot([line_segments[:,0],line_segments[:,2]],[line_segments[:,1],line_segments[:,3]],'-',color=edge_color,linewidth=self.line_width)
+ ax.plot(line_segments[:,0],line_segments[:,1],'.',color=vertex_color,markersize=self.marker_size)
+ ax.plot(line_segments[:,2],line_segments[:,3],'.',color=vertex_color,markersize=self.marker_size)
diff --git a/scalelsd/base/utils/__init__.py b/scalelsd/base/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/scalelsd/base/utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/scalelsd/base/utils/logger.py b/scalelsd/base/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..2279763a5b465c2937b6458f9a65efb70265658b
--- /dev/null
+++ b/scalelsd/base/utils/logger.py
@@ -0,0 +1,30 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import logging
+import os
+import sys
+from pythonjsonlogger import jsonlogger
+
+
+def setup_logger(name, save_dir, out_file='log.txt', json_format=False, rank=0):
+ """Return the logger called `name`, set to DEBUG, with handlers attached.
+
+ A stdout stream handler is added when `rank == 0`; a file handler writing
+ to `save_dir/out_file` is added when `save_dir` is truthy (the directory is
+ created if missing). With `json_format` the handlers use
+ `jsonlogger.JsonFormatter`, otherwise `logging.Formatter`, both with the
+ same format string.
+
+ NOTE(review): handlers are added on every call, so calling this twice for
+ the same `name` duplicates each log record.
+ NOTE(review): confirm whether the `save_dir` branch is meant to run for
+ every rank or only for rank 0.
+ """
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+
+ # Pick the formatter shared by all handlers created below.
+ if json_format:
+ formatter = jsonlogger.JsonFormatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
+ else:
+ formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
+
+ # Console output only for rank 0.
+ if rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # Optional log file inside save_dir.
+ if save_dir:
+ os.makedirs(save_dir, exist_ok=True)
+ fh = logging.FileHandler(os.path.join(save_dir, out_file))
+ fh.setLevel(logging.DEBUG)
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
+
+ return logger
diff --git a/scalelsd/base/utils/metric_logger.py b/scalelsd/base/utils/metric_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4db625494f6d67a23f61ded7f79a63accc49c20
--- /dev/null
+++ b/scalelsd/base/utils/metric_logger.py
@@ -0,0 +1,77 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from collections import defaultdict
+from collections import deque
+
+import torch
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20):
+ self.deque = deque(maxlen=window_size)
+ self.series = []
+ self.total = 0.0
+ self.count = 0
+
+ def update(self, value):
+ self.deque.append(value)
+ self.series.append(value)
+ self.count += 1
+ self.total += value
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque))
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ keys = sorted(self.meters)
+ # for name, meter in self.meters.items():
+ for name in keys:
+ meter = self.meters[name]
+ loss_str.append(
+ "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
+ )
+ return self.delimiter.join(loss_str)
+
+ def tensorborad(self, iteration, writter, phase='train'):
+ for name, meter in self.meters.items():
+ if 'loss' in name:
+ # writter.add_scalar('average/{}'.format(name), meter.avg, iteration)
+ writter.add_scalar('{}/global/{}'.format(phase,name), meter.global_avg, iteration)
+ # writter.add_scalar('median/{}'.format(name), meter.median, iteration)
+
diff --git a/scalelsd/base/wireframe.py b/scalelsd/base/wireframe.py
new file mode 100644
index 0000000000000000000000000000000000000000..be96e41dc375ea0bdf58231b73fe71f7bd26aa9a
--- /dev/null
+++ b/scalelsd/base/wireframe.py
@@ -0,0 +1,110 @@
+import copy
+import math
+import numpy as np
+import torch
+import json
+
+class WireframeGraph:
+ def __init__(self,
+ vertices: torch.Tensor,
+ v_confidences: torch.Tensor,
+ edges: torch.Tensor,
+ edge_weights: torch.Tensor,
+ frame_width: int,
+ frame_height: int):
+ self.vertices = vertices
+ self.v_confidences = v_confidences
+ self.edges = edges
+ self.weights = edge_weights
+ self.frame_width = frame_width
+ self.frame_height = frame_height
+
+ @classmethod
+ def xyxy2indices(cls,junctions, lines):
+ # junctions: (N,2)
+ # lines: (M,4)
+ # return: (M,2)
+ dist1 = torch.norm(junctions[None,:,:]-lines[:,None,:2],dim=-1)
+ dist2 = torch.norm(junctions[None,:,:]-lines[:,None,2:],dim=-1)
+ idx1 = torch.argmin(dist1,dim=-1)
+ idx2 = torch.argmin(dist2,dim=-1)
+ return torch.stack((idx1,idx2),dim=-1)
+ @classmethod
+ def load_json(cls, fname):
+ with open(fname,'r') as f:
+ data = json.load(f)
+
+
+ vertices = torch.tensor(data['vertices'])
+ v_confidences = torch.tensor(data['vertices-score'])
+ edges = torch.tensor(data['edges'])
+ edge_weights = torch.tensor(data['edges-weights'])
+ height = data['height']
+ width = data['width']
+
+ return WireframeGraph(vertices,v_confidences,edges,edge_weights,width,height)
+
+ @property
+ def is_empty(self):
+ for key, val in self.__dict__.items():
+ if val is None:
+ return True
+ return False
+
+ @property
+ def num_vertices(self):
+ if self.is_empty:
+ return 0
+ return self.vertices.shape[0]
+
+ @property
+ def num_edges(self):
+ if self.is_empty:
+ return 0
+ return self.edges.shape[0]
+
+
+ def line_segments(self, threshold = 0.05, device=None, to_np=False):
+ is_valid = self.weights>threshold
+ p1 = self.vertices[self.edges[is_valid,0]]
+ p2 = self.vertices[self.edges[is_valid,1]]
+ ps = self.weights[is_valid]
+
+ lines = torch.cat((p1,p2,ps[:,None]),dim=-1)
+ if device is not None:
+ lines = lines.to(device)
+ if to_np:
+ lines = lines.cpu().numpy()
+
+ return lines
+ # if device != self.device:
+
+ def rescale(self, image_width, image_height):
+ scale_x = float(image_width)/float(self.frame_width)
+ scale_y = float(image_height)/float(self.frame_height)
+
+ self.vertices[:,0] *= scale_x
+ self.vertices[:,1] *= scale_y
+ self.frame_width = image_width
+ self.frame_height = image_height
+
+ def jsonize(self):
+ return {
+ 'vertices': self.vertices.cpu().tolist(),
+ 'vertices-score': self.v_confidences.cpu().tolist(),
+ 'edges': self.edges.cpu().tolist(),
+ 'edges-weights': self.weights.cpu().tolist(),
+ 'height': self.frame_height,
+ 'width': self.frame_width,
+ }
+ def __repr__(self) -> str:
+ return "WireframeGraph\n"+\
+ "Vertices: {}\n".format(self.num_vertices)+\
+ "Edges: {}\n".format(self.num_edges,) + \
+ "Frame size (HxW): {}x{}".format(self.frame_height,self.frame_width)
+
+#graph = WireframeGraph()
+if __name__ == "__main__":
+ graph = WireframeGraph.load_json('NeuS/public_data/bmvs_clock/hawp/000.json')
+ print(graph)
+
\ No newline at end of file
diff --git a/scalelsd/encoder/__init__.py b/scalelsd/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f399ca54a11f080f9ac7bdab52e89d6bf738bdf0
--- /dev/null
+++ b/scalelsd/encoder/__init__.py
@@ -0,0 +1 @@
+from .hafm import HAFMencoder
\ No newline at end of file
diff --git a/scalelsd/encoder/hafm.py b/scalelsd/encoder/hafm.py
new file mode 100644
index 0000000000000000000000000000000000000000..afc792dcef8488ec35061b0eed89e002c51bc1ba
--- /dev/null
+++ b/scalelsd/encoder/hafm.py
@@ -0,0 +1,152 @@
+import torch
+import numpy as np
+from torch.utils.data.dataloader import default_collate
+
+from halt import _C
+
+class HAFMencoder(object):
+ def __init__(self, cfg):
+ self.dis_th = cfg.ENCODER.DIS_TH
+ self.ang_th = cfg.ENCODER.ANG_TH
+ self.num_static_pos_lines = cfg.ENCODER.NUM_STATIC_POS_LINES
+ self.num_static_neg_lines = cfg.ENCODER.NUM_STATIC_NEG_LINES
+ def __call__(self,annotations):
+ targets = []
+ metas = []
+ for ann in annotations:
+ t,m = self._process_per_image(ann)
+ targets.append(t)
+ metas.append(m)
+
+ return default_collate(targets),metas
+
+ def adjacent_matrix(self, n, edges, device):
+ mat = torch.zeros(n+1,n+1,dtype=torch.bool,device=device)
+ if edges.size(0)>0:
+ mat[edges[:,0], edges[:,1]] = 1
+ mat[edges[:,1], edges[:,0]] = 1
+ return mat
+
+ def _process_per_image(self,ann):
+ junctions = ann['junctions']
+ device = junctions.device
+ height, width = ann['height'], ann['width']
+ jmap = torch.zeros((height,width),device=device)
+ joff = torch.zeros((2,height,width),device=device,dtype=torch.float32)
+ # junctions[:,0] = junctions[:,0].clamp(min=0,max=width-1)
+ # junctions[:,1] = junctions[:,1].clamp(min=0,max=height-1)
+ xint,yint = junctions[:,0].long(), junctions[:,1].long()
+ off_x = junctions[:,0] - xint.float()-0.5
+ off_y = junctions[:,1] - yint.float()-0.5
+
+ jmap[yint,xint] = 1
+ joff[0,yint,xint] = off_x
+ joff[1,yint,xint] = off_y
+
+ edges_positive = ann['edges_positive']
+ edges_negative = ann['edges_negative']
+
+ pos_mat = self.adjacent_matrix(junctions.size(0),edges_positive,device)
+ neg_mat = self.adjacent_matrix(junctions.size(0),edges_negative,device)
+ lines = torch.cat((junctions[edges_positive[:,0]], junctions[edges_positive[:,1]]),dim=-1)
+ lines_neg = torch.cat((junctions[edges_negative[:2000,0]],junctions[edges_negative[:2000,1]]),dim=-1)
+ lmap, _, _ = _C.encodels(lines,height,width,height,width,lines.size(0))
+
+ center_points = (lines[:,:2] + lines[:,2:])/2.0
+ cmap = torch.zeros((height,width),device=device)
+ cxint, cyint = center_points[:,0].long(), center_points[:,1].long()
+ cmap[cyint,cxint] = 1
+
+ # yy,xx = torch.meshgrid(torch.arange(width,device=device),torch.arange(width,device=device))
+ # gaussian = torch.exp(-((yy[:,:,None]-center_points[None,None,:,1])**2 + (xx[:,:,None]-center_points[None,None,:,0])**2)/(2*(2*2)))
+ # cmap = gaussian.max(dim=-1)[0]
+
+ lpos = np.random.permutation(lines.cpu().numpy())[:self.num_static_pos_lines]
+ lneg = np.random.permutation(lines_neg.cpu().numpy())[:self.num_static_neg_lines]
+ # lpos = lines[torch.randperm(lines.size(0),device=device)][:self.num_static_pos_lines]
+ # lneg = lines_neg[torch.randperm(lines_neg.size(0),device=device)][:self.num_static_neg_lines]
+ lpos = torch.from_numpy(lpos).to(device)
+ lneg = torch.from_numpy(lneg).to(device)
+
+ lpre = torch.cat((lpos,lneg),dim=0)
+ _swap = (torch.rand(lpre.size(0))>0.5).to(device)
+ lpre[_swap] = lpre[_swap][:,[2,3,0,1]]
+ lpre_label = torch.cat(
+ [
+ torch.ones(lpos.size(0),device=device),
+ torch.zeros(lneg.size(0),device=device)
+ ])
+
+ meta = {
+ 'junc': junctions,
+ 'Lpos': pos_mat,
+ 'Lneg': neg_mat,
+ 'lpre': lpre,
+ 'lpre_label': lpre_label,
+ 'lines': lines,
+ }
+
+
+ dismap = torch.sqrt(lmap[0]**2+lmap[1]**2)[None]
+ def _normalize(inp):
+ mag = torch.sqrt(inp[0]*inp[0]+inp[1]*inp[1])
+ return inp/(mag+1e-6)
+ md_map = _normalize(lmap[:2])
+ st_map = _normalize(lmap[2:4])
+ ed_map = _normalize(lmap[4:])
+ st_map = lmap[2:4]
+ ed_map = lmap[4:]
+
+ md_ = md_map.reshape(2,-1).t()
+ st_ = st_map.reshape(2,-1).t()
+ ed_ = ed_map.reshape(2,-1).t()
+ Rt = torch.cat(
+ (torch.cat((md_[:,None,None,0],md_[:,None,None,1]),dim=2),
+ torch.cat((-md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1)
+ R = torch.cat(
+ (torch.cat((md_[:,None,None,0], -md_[:,None,None,1]),dim=2),
+ torch.cat((md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1)
+
+ Rtst_ = torch.matmul(Rt, st_[:,:,None]).squeeze(-1).t()
+ Rted_ = torch.matmul(Rt, ed_[:,:,None]).squeeze(-1).t()
+ swap_mask = (Rtst_[1]<0)*(Rted_[1]>0)
+ pos_ = Rtst_.clone()
+ neg_ = Rted_.clone()
+ temp = pos_[:,swap_mask]
+ pos_[:,swap_mask] = neg_[:,swap_mask]
+ neg_[:,swap_mask] = temp
+
+ pos_[0] = pos_[0].clamp(min=1e-9)
+ pos_[1] = pos_[1].clamp(min=1e-9)
+ neg_[0] = neg_[0].clamp(min=1e-9)
+ neg_[1] = neg_[1].clamp(max=-1e-9)
+
+ mask = (dismap.view(-1)<=self.dis_th).float()
+
+ pos_map = pos_.reshape(-1,height,width)
+ neg_map = neg_.reshape(-1,height,width)
+
+ md_angle = torch.atan2(md_map[1], md_map[0])
+ pos_angle = torch.atan2(pos_map[1],pos_map[0])
+ neg_angle = torch.atan2(neg_map[1],neg_map[0])
+
+ mask *= (pos_angle.reshape(-1)>self.ang_th*np.pi/2.0)
+ mask *= (neg_angle.reshape(-1)<-self.ang_th*np.pi/2.0)
+
+ pos_angle_n = pos_angle/(np.pi/2)
+ neg_angle_n = -neg_angle/(np.pi/2)
+ md_angle_n = md_angle/(np.pi*2) + 0.5
+ mask = mask.reshape(height,width)
+
+
+ hafm_ang = torch.cat((md_angle_n[None],pos_angle_n[None],neg_angle_n[None],),dim=0)
+ hafm_dis = dismap.clamp(max=self.dis_th)/self.dis_th
+ mask = mask[None]
+ target = {'jloc':jmap[None],
+ 'joff':joff,
+ 'cloc': cmap[None],
+ 'md': hafm_ang,
+ 'dis': hafm_dis,
+ 'mask': mask
+ }
+ return target, meta
\ No newline at end of file
diff --git a/scalelsd/ssl/backbones/__init__.py b/scalelsd/ssl/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13aca1604db01f2239a6f53b911d79f43613d5de
--- /dev/null
+++ b/scalelsd/ssl/backbones/__init__.py
@@ -0,0 +1 @@
+from .build import build_backbone
\ No newline at end of file
diff --git a/scalelsd/ssl/backbones/build.py b/scalelsd/ssl/backbones/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f1bed422626dc29a8e0a6b56e954202281f3e17
--- /dev/null
+++ b/scalelsd/ssl/backbones/build.py
@@ -0,0 +1,28 @@
+from .dpt.models import DPTFieldModel
+
+def build_dpt(
+ basemodel = "vitb_rn50_384",
+ features=256,
+ readout = "project",
+ channels_last = False,
+ use_bn = True,
+ enable_attention_hooks = False,
+ head_size = [[3],[1],[1],[2],[2]],
+ use_layer_scale = False,
+ **kwargs):
+
+ model = DPTFieldModel(
+ features=features,
+ backbone=basemodel,
+ readout=readout,
+ channels_last=channels_last,
+ use_bn=use_bn,
+ enable_attention_hooks=enable_attention_hooks,
+ head_size=head_size,
+ use_layer_scale=use_layer_scale
+ )
+
+ return model
+
+def build_backbone(**kwargs):
+ return build_dpt(**kwargs)
diff --git a/scalelsd/ssl/backbones/dpt/__init__.py b/scalelsd/ssl/backbones/dpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scalelsd/ssl/backbones/dpt/base_model.py b/scalelsd/ssl/backbones/dpt/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c2e0e93b0495f48a3405546b6fe1969be3480a2
--- /dev/null
+++ b/scalelsd/ssl/backbones/dpt/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device("cpu"))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
diff --git a/scalelsd/ssl/backbones/dpt/blocks.py b/scalelsd/ssl/backbones/dpt/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..e03f29144f8af59666eb2e065b5010cb403884a4
--- /dev/null
+++ b/scalelsd/ssl/backbones/dpt/blocks.py
@@ -0,0 +1,388 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+
+def _make_encoder(
+ backbone,
+ features,
+ use_pretrained,
+ groups=1,
+ expand=False,
+ exportable=True,
+ hooks=None,
+ use_vit_only=False,
+ use_readout="ignore",
+ enable_attention_hooks=False,
+ use_layer_scale=False,
+):
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained,
+ hooks=hooks,
+ use_readout=use_readout,
+ enable_attention_hooks=enable_attention_hooks,
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ enable_attention_hooks=enable_attention_hooks,
+ use_layer_scale=use_layer_scale,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained,
+ hooks=hooks,
+ use_readout=use_readout,
+ enable_attention_hooks=enable_attention_hooks,
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+ scratch = _make_scratch(
+ [256, 512, 1024, 2048], features, groups=groups, expand=expand
+ ) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand == True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0],
+ out_shape1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1],
+ out_shape2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2],
+ out_shape3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3],
+ out_shape4,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+
+ return scratch
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+class Interpolate(nn.Module):
+ """Interpolation module."""
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x,
+ scale_factor=self.scale_factor,
+ mode=self.mode,
+ align_corners=self.align_corners,
+ )
+
+ # x = self.interp(x, scale_factor=self.scale_factor)
+ # x = self.interp(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=True)
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups = 1
+
+ self.conv1 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+
+ self.conv2 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+
+ if self.bn == True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn == True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn == True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups = 1
+
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ groups=1,
+ )
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
diff --git a/scalelsd/ssl/backbones/dpt/midas_net.py b/scalelsd/ssl/backbones/dpt/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..34d6d7e77b464e7df45b7ab45174a7413d8fbc89
--- /dev/null
+++ b/scalelsd/ssl/backbones/dpt/midas_net.py
@@ -0,0 +1,77 @@
+"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet_large(BaseModel):
+    """Network for monocular depth estimation."""
+
+    def __init__(self, path=None, features=256, non_negative=True):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            non_negative (bool, optional): When True the output head ends with a
+                ReLU, otherwise with an identity layer. Defaults to True.
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet_large, self).__init__()
+
+        # The encoder is asked for pretrained weights only when a checkpoint path is given.
+        use_pretrained = path is not None
+
+        self.pretrained, self.scratch = _make_encoder(
+            backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
+        )
+
+        self.scratch.refinenet4 = FeatureFusionBlock(features)
+        self.scratch.refinenet3 = FeatureFusionBlock(features)
+        self.scratch.refinenet2 = FeatureFusionBlock(features)
+        self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+        )
+
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth, with the size-1 channel dimension (dim 1) squeezed out
+        """
+
+        # Encoder: each stage consumes the previous stage's output.
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        # Decoder: start from the last stage and fuse in each earlier stage in turn.
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
diff --git a/scalelsd/ssl/backbones/dpt/models.py b/scalelsd/ssl/backbones/dpt/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..f18d780cce763b4389ec4ac5869933ac90a886ac
--- /dev/null
+++ b/scalelsd/ssl/backbones/dpt/models.py
@@ -0,0 +1,115 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+)
+from ..multi_task_head import MultitaskHead
+
+
+def _make_fusion_block(features, use_bn):
+    """Build a ``FeatureFusionBlock_custom`` with a non-inplace ReLU.
+
+    Args:
+        features (int): number of features of the block
+        use_bn (bool): forwarded as the block's ``bn`` flag
+
+    Returns:
+        FeatureFusionBlock_custom: block with deconv and expand off, align_corners on
+    """
+    relu = nn.ReLU(False)
+    return FeatureFusionBlock_custom(
+        features, relu, deconv=False, bn=use_bn, expand=False, align_corners=True
+    )
+
+
+class DPT(BaseModel):
+    """Encoder from ``_make_encoder``, four fusion blocks and a caller-supplied head."""
+
+    def __init__(
+        self,
+        head,
+        features=256,
+        backbone="vitb_rn50_384",
+        readout="project",
+        channels_last=False,
+        use_bn=False,
+        enable_attention_hooks=False,
+        use_layer_scale=False,
+    ):
+        """Init.
+
+        Args:
+            head (nn.Module): module applied to the last fused feature map.
+            features (int): number of features of the fusion blocks.
+            backbone (str): one of "vitb_rn50_384", "vitb16_384", "vitl16_384".
+            readout (str): forwarded to ``_make_encoder`` as ``use_readout``.
+            channels_last (bool): convert the input to channels-last memory
+                format in ``forward``.
+            use_bn (bool): enable batch norm inside the fusion blocks.
+            enable_attention_hooks (bool): forwarded to ``_make_encoder``.
+            use_layer_scale (bool): forwarded to ``_make_encoder``.
+
+        Raises:
+            KeyError: if ``backbone`` is not one of the supported names.
+        """
+
+        super(DPT, self).__init__()
+
+        self.channels_last = channels_last
+
+        # Block indices handed to the encoder as ``hooks``, per supported backbone.
+        hooks = {
+            "vitb_rn50_384": [0, 1, 8, 11],
+            "vitb16_384": [2, 5, 8, 11],
+            "vitl16_384": [5, 11, 17, 23],
+        }
+
+        # Instantiate backbone and reassemble blocks
+        self.pretrained, self.scratch = _make_encoder(
+            backbone,
+            features,
+            False,  # Set to true if you want to train from scratch, uses ImageNet weights
+            groups=1,
+            expand=False,
+            exportable=False,
+            hooks=hooks[backbone],
+            use_readout=readout,
+            enable_attention_hooks=enable_attention_hooks,
+            use_layer_scale=use_layer_scale,
+        )
+
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+        self.scratch.output_conv = head
+
+    def forward(self, x):
+        """Run encoder, fusion decoder and head on ``x`` and return the head output."""
+        if self.channels_last == True:
+            # Tensor.contiguous is not in-place: the converted tensor must be
+            # rebound, otherwise the requested memory format never takes effect.
+            x = x.contiguous(memory_format=torch.channels_last)
+
+        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        # Start from the last stage and fuse in each earlier stage in turn.
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return out
+
+class DPTFieldModel(DPT):
+    """DPT whose head is a 3x3 conv + ReLU followed by a ``MultitaskHead``."""
+
+    def __init__(self, path=None, non_negative=True, head_size=None, **kwargs):
+        """Init.
+
+        Args:
+            path: accepted for interface compatibility; not used here.
+            non_negative: accepted for interface compatibility; not used here.
+            head_size (list[list[int]], optional): channel counts per output
+                head. Defaults to ``[[3], [1], [1], [2], [2]]``.
+            **kwargs: forwarded to ``DPT.__init__``. ``features`` (default 256)
+                also sizes the head; ``use_bn`` is always forced to True.
+        """
+        # A None sentinel avoids sharing one mutable default list across instances.
+        if head_size is None:
+            head_size = [[3], [1], [1], [2], [2]]
+
+        features = kwargs.get("features", 256)
+
+        kwargs["use_bn"] = True
+
+        # Total output channels: sum over the flattened head_size lists.
+        num_class = sum(sum(head_size, []))
+        head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            MultitaskHead(features // 2, num_class, head_size=head_size),
+        )
+
+        super().__init__(head, **kwargs)
+
+        self.stride = 2
+
+    def forward(self, x):
+        """Return ``(head_output, None)``; single-channel input is tiled to 3 channels."""
+        if x.shape[1] == 1:
+            x = torch.cat([x, x, x], dim=1)
+
+        out = super().forward(x)
+        return out, None
+
+
diff --git a/scalelsd/ssl/backbones/dpt/transforms.py b/scalelsd/ssl/backbones/dpt/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..399adbcdad096ae3fb8a190ecd3ec5483a897251
--- /dev/null
+++ b/scalelsd/ssl/backbones/dpt/transforms.py
@@ -0,0 +1,231 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Resize the sample to ensure the given size. Keeps aspect ratio.
+
+    When resizing is needed, the "image", "disparity" and "mask" entries of
+    ``sample`` are replaced in place.
+
+    Args:
+        sample (dict): sample with "image", "disparity" and "mask" arrays
+        size (tuple): minimum size, compared against dims 0 and 1 of
+            ``sample["disparity"]``
+        image_interpolation_method: cv2 interpolation flag used for the image
+
+    Returns:
+        tuple: new size (the current size when no resizing is needed)
+    """
+    shape = list(sample["disparity"].shape)
+
+    if shape[0] >= size[0] and shape[1] >= size[1]:
+        # Already large enough: return the size, like the resizing path does,
+        # rather than the sample dict.
+        return tuple(shape)
+
+    # One common factor, large enough for both dimensions.
+    scale = max(size[0] / shape[0], size[1] / shape[1])
+
+    shape[0] = math.ceil(scale * shape[0])
+    shape[1] = math.ceil(scale * shape[1])
+
+    # resize (cv2 takes the reversed size)
+    sample["image"] = cv2.resize(
+        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+    )
+
+    sample["disparity"] = cv2.resize(
+        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+    )
+    # The mask is resized as float32 and converted back to bool afterwards.
+    sample["mask"] = cv2.resize(
+        sample["mask"].astype(np.float32),
+        tuple(shape[::-1]),
+        interpolation=cv2.INTER_NEAREST,
+    )
+    sample["mask"] = sample["mask"].astype(bool)
+
+    return tuple(shape)
+
+
+class Resize(object):
+    """Resize sample to given size (width, height)."""
+
+    def __init__(
+        self,
+        width,
+        height,
+        resize_target=True,
+        keep_aspect_ratio=False,
+        ensure_multiple_of=1,
+        resize_method="lower_bound",
+        image_interpolation_method=cv2.INTER_AREA,
+    ):
+        """Init.
+
+        Args:
+            width (int): desired output width
+            height (int): desired output height
+            resize_target (bool, optional):
+                True: Resize the full sample (image, mask, target).
+                False: Resize image only.
+                Defaults to True.
+            keep_aspect_ratio (bool, optional):
+                True: Keep the aspect ratio of the input sample.
+                Output sample might not have the given width and height, and
+                resize behaviour depends on the parameter 'resize_method'.
+                Defaults to False.
+            ensure_multiple_of (int, optional):
+                Output width and height is constrained to be multiple of this parameter.
+                Defaults to 1.
+            resize_method (str, optional):
+                "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
+                Defaults to "lower_bound".
+            image_interpolation_method (optional): cv2 interpolation flag used
+                for the image. Defaults to cv2.INTER_AREA.
+        """
+        self.__width = width
+        self.__height = height
+
+        self.__resize_target = resize_target
+        self.__keep_aspect_ratio = keep_aspect_ratio
+        self.__multiple_of = ensure_multiple_of
+        self.__resize_method = resize_method
+        self.__image_interpolation_method = image_interpolation_method
+
+    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+        """Round ``x`` to a multiple of ``ensure_multiple_of``.
+
+        Rounds to nearest first; falls back to rounding down when the result
+        exceeds ``max_val`` and to rounding up when it is below ``min_val``.
+        """
+        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if max_val is not None and y > max_val:
+            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if y < min_val:
+            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        return y
+
+    def get_size(self, width, height):
+        """Compute the output ``(new_width, new_height)`` for an input of the given size.
+
+        Raises:
+            ValueError: if ``resize_method`` is not one of the supported names.
+        """
+        # determine new height and width
+        scale_height = self.__height / height
+        scale_width = self.__width / width
+
+        if self.__keep_aspect_ratio:
+            if self.__resize_method == "lower_bound":
+                # scale such that output size is lower bound
+                if scale_width > scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "upper_bound":
+                # scale such that output size is upper bound
+                if scale_width < scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "minimal":
+                # scale as little as possible
+                if abs(1 - scale_width) < abs(1 - scale_height):
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            else:
+                raise ValueError(
+                    f"resize_method {self.__resize_method} not implemented"
+                )
+
+        if self.__resize_method == "lower_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, min_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, min_val=self.__width
+            )
+        elif self.__resize_method == "upper_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, max_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, max_val=self.__width
+            )
+        elif self.__resize_method == "minimal":
+            new_height = self.constrain_to_multiple_of(scale_height * height)
+            new_width = self.constrain_to_multiple_of(scale_width * width)
+        else:
+            raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+        return (new_width, new_height)
+
+    def __call__(self, sample):
+        """Resize ``sample["image"]`` and, if enabled, the targets present in the sample."""
+        width, height = self.get_size(
+            sample["image"].shape[1], sample["image"].shape[0]
+        )
+
+        # resize sample
+        sample["image"] = cv2.resize(
+            sample["image"],
+            (width, height),
+            interpolation=self.__image_interpolation_method,
+        )
+
+        if self.__resize_target:
+            if "disparity" in sample:
+                sample["disparity"] = cv2.resize(
+                    sample["disparity"],
+                    (width, height),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+
+            if "depth" in sample:
+                sample["depth"] = cv2.resize(
+                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+                )
+
+            # The mask is optional like the other targets; image-only samples
+            # must not fail with a KeyError here.
+            if "mask" in sample:
+                sample["mask"] = cv2.resize(
+                    sample["mask"].astype(np.float32),
+                    (width, height),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+                sample["mask"] = sample["mask"].astype(bool)
+
+        return sample
+
+
+class NormalizeImage(object):
+    """Normalize image by given mean and std."""
+
+    def __init__(self, mean, std):
+        """Store the mean to subtract and the std to divide by."""
+        self.__mean = mean
+        self.__std = std
+
+    def __call__(self, sample):
+        """Replace ``sample["image"]`` with ``(image - mean) / std`` and return the sample."""
+        centered = sample["image"] - self.__mean
+        sample["image"] = centered / self.__std
+        return sample
+
+
+class PrepareForNet(object):
+    """Prepare sample for usage as network input."""
+
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        """Convert the sample's arrays to contiguous float32 and return the sample.
+
+        The image axes are reordered with ``(2, 0, 1)`` first; "mask",
+        "disparity" and "depth" are converted only when present.
+        """
+        chw = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(chw).astype(np.float32)
+
+        for key in ("mask", "disparity", "depth"):
+            if key in sample:
+                as_float = sample[key].astype(np.float32)
+                sample[key] = np.ascontiguousarray(as_float)
+
+        return sample
diff --git a/scalelsd/ssl/backbones/dpt/vit.py b/scalelsd/ssl/backbones/dpt/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..208324a401c2af33a6c11be530d54c97d50d6458
--- /dev/null
+++ b/scalelsd/ssl/backbones/dpt/vit.py
@@ -0,0 +1,586 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+# Module-level store filled by the hooks created in get_activation.
+activations = {}
+
+
+def get_activation(name):
+    """Return a forward hook that records the hooked module's output under ``name``."""
+
+    def hook(model, input, output):
+        # Each call overwrites the previous entry for this name.
+        activations.update({name: output})
+
+    return hook
+
+
+# Module-level store filled by the hooks created in get_attention.
+attention = {}
+
+
+def get_attention(name):
+    """Return a forward hook that recomputes softmax attention and stores it under ``name``.
+
+    The hook reads ``qkv``, ``num_heads`` and ``scale`` from the hooked module
+    and recomputes the attention weights from the module's first input.
+    """
+
+    def hook(module, input, output):
+        tokens = input[0]
+        B, N, C = tokens.shape
+        head_dim = C // module.num_heads
+
+        qkv = module.qkv(tokens).reshape(B, N, 3, module.num_heads, head_dim)
+        qkv = qkv.permute(2, 0, 3, 1, 4).contiguous()
+        # Only queries and keys are needed for the attention weights.
+        q, k = qkv[0], qkv[1]
+
+        scores = (q @ k.transpose(-2, -1).contiguous()) * module.scale
+        attention[name] = scores.softmax(dim=-1)
+
+    return hook
+
+
+def get_mean_attention_map(attn, token, shape):
+    """Upsample one token's attention (first column dropped) and average over dim 0.
+
+    Args:
+        attn (tensor): attention weights, indexed as ``attn[:, :, token, 1:]``
+        token (int): index of the query token to visualise
+        shape: size whose entries 2 and 3 give the output resolution; the
+            attention grid is taken as ``shape[2] // 16`` by ``shape[3] // 16``
+
+    Returns:
+        tensor: mean over dim 0 of the upsampled map after ``squeeze(0)``
+    """
+    grid = torch.Size([shape[2] // 16, shape[3] // 16])
+    token_attn = attn[:, :, token, 1:].unflatten(2, grid).float()
+
+    upsampled = F.interpolate(
+        token_attn, size=shape[2:], mode="bicubic", align_corners=False
+    )
+
+    return upsampled.squeeze(0).mean(0)
+
+
+class Slice(nn.Module):
+    """Drop the first ``start_index`` entries along dim 1."""
+
+    def __init__(self, start_index=1):
+        super().__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        first_kept = self.start_index
+        return x[:, first_kept:]
+
+
+class AddReadout(nn.Module):
+    """Add the readout entry to every remaining entry along dim 1.
+
+    The readout is entry 0, or the mean of entries 0 and 1 when
+    ``start_index`` is 2.
+    """
+
+    def __init__(self, start_index=1):
+        super().__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0]
+        remaining = x[:, self.start_index :]
+        return remaining + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+    """Concatenate entry 0 onto every remaining entry and project back to ``in_features``."""
+
+    def __init__(self, in_features, start_index=1):
+        super().__init__()
+        self.start_index = start_index
+
+        # Linear maps the concatenated 2 * in_features back to in_features, then GELU.
+        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+    def forward(self, x):
+        remaining = x[:, self.start_index :]
+        # Entry 0 is broadcast to the shape of the remaining entries.
+        readout = x[:, 0].unsqueeze(1).expand_as(remaining)
+        return self.project(torch.cat((remaining, readout), -1))
+
+
+class Transpose(nn.Module):
+    """Swap two dimensions and return a contiguous tensor."""
+
+    def __init__(self, dim0, dim1):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x):
+        return x.transpose(self.dim0, self.dim1).contiguous()
+
+
+def forward_vit(pretrained, x):
+    """Run the backbone on ``x`` and return the four post-processed hooked activations.
+
+    Each ``act_postprocessN`` is applied in two parts: entries ``[0:2]``, then,
+    for 3-D results only, an unflatten of dim 2 into the patch grid, then
+    entries ``[3:]``. Entry 2 of each sequence is not called here.
+
+    Returns:
+        tuple: (layer_1, layer_2, layer_3, layer_4)
+    """
+    _, _, h, w = x.shape
+
+    # Called for its side effect: the forward hooks fill pretrained.activations.
+    pretrained.model.forward_flex(x)
+
+    postprocess = (
+        pretrained.act_postprocess1,
+        pretrained.act_postprocess2,
+        pretrained.act_postprocess3,
+        pretrained.act_postprocess4,
+    )
+    layers = [pretrained.activations[key] for key in ("1", "2", "3", "4")]
+
+    layers = [post[0:2](layer) for post, layer in zip(postprocess, layers)]
+
+    patch = pretrained.model.patch_size
+    unflatten = nn.Unflatten(2, torch.Size([h // patch[1], w // patch[0]]))
+    layers = [unflatten(layer) if layer.ndim == 3 else layer for layer in layers]
+
+    layers = [post[3:](layer) for post, layer in zip(postprocess, layers)]
+
+    return tuple(layers)
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+    """Bilinearly resize the grid part of a position embedding to ``gs_h`` x ``gs_w``.
+
+    The first ``self.start_index`` entries are kept unchanged; the remaining
+    entries are treated as a square grid, interpolated, and flattened again.
+    """
+    prefix = posemb[:, : self.start_index]
+    grid = posemb[0, self.start_index :]
+
+    # Side length of the stored grid, taken from the number of grid entries.
+    old_side = int(math.sqrt(len(grid)))
+
+    grid = grid.reshape(1, old_side, old_side, -1).permute(0, 3, 1, 2)
+    grid = F.interpolate(grid, size=(gs_h, gs_w), mode="bilinear")
+    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+    return torch.cat([prefix, grid], dim=1)
+
+
+def forward_flex(self, x):
+    """Forward pass that resizes the position embedding to the input's patch grid.
+
+    Embeds ``x`` into patches, prepends the class token (and the distillation
+    token when the model has one), adds the resized position embedding, and
+    runs the blocks and the final norm.
+    """
+    batch, _, height, width = x.shape
+
+    pos_embed = self._resize_pos_embed(
+        self.pos_embed, height // self.patch_size[1], width // self.patch_size[0]
+    )
+
+    embed = self.patch_embed
+    if hasattr(embed, "backbone"):
+        x = embed.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+
+    x = embed.proj(x).flatten(2).transpose(1, 2).contiguous()
+
+    # Prefix tokens: class token first, then the optional distillation token.
+    prefix = [self.cls_token.expand(batch, -1, -1)]
+    dist_token = getattr(self, "dist_token", None)
+    if dist_token is not None:
+        prefix.append(dist_token.expand(batch, -1, -1))
+    x = torch.cat((*prefix, x), dim=1)
+
+    x = self.pos_drop(x + pos_embed)
+
+    for blk in self.blocks:
+        x = blk(x)
+
+    return self.norm(x)
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+    """Build one readout-handling module per entry of ``features``.
+
+    Args:
+        vit_features (int): feature width handed to ``ProjectReadout``.
+        features (list): only its length is used (one module per entry).
+        use_readout (str): "ignore" (``Slice``), "add" (``AddReadout``) or
+            "project" (``ProjectReadout``).
+        start_index (int): forwarded to the readout modules.
+
+    Returns:
+        list: readout modules. "ignore" and "add" repeat a single shared
+        instance; "project" creates a separate instance per entry.
+
+    Raises:
+        AssertionError: if ``use_readout`` is not one of the three options.
+    """
+    if use_readout == "ignore":
+        return [Slice(start_index)] * len(features)
+    if use_readout == "add":
+        return [AddReadout(start_index)] * len(features)
+    if use_readout == "project":
+        return [ProjectReadout(vit_features, start_index) for _ in features]
+
+    # Raised explicitly instead of `assert False`: assert statements are
+    # removed under `python -O`, which would let an invalid option fall through.
+    raise AssertionError(
+        "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+    )
+
+
+def _make_vit_b16_backbone(
+    model,
+    features=[96, 192, 384, 768],
+    size=[384, 384],
+    hooks=[2, 5, 8, 11],
+    vit_features=768,
+    use_readout="ignore",
+    start_index=1,
+    enable_attention_hooks=False,
+):
+    """Wrap a ViT ``model`` with activation hooks and four post-processing stacks.
+
+    Args:
+        model: model exposing ``blocks``; it is stored as ``pretrained.model``.
+        features (list): output channels of the four post-processing stacks.
+        size (list): reference size; each stack unflattens to ``size // 16``.
+        hooks (list): indices of the four blocks whose outputs are captured.
+        vit_features (int): input channels of the 1x1 projections.
+        use_readout (str): forwarded to ``get_readout_oper``.
+        start_index (int): forwarded to ``get_readout_oper`` and set on the model.
+        enable_attention_hooks (bool): also hook the ``attn`` submodules.
+
+    Returns:
+        nn.Module: container holding the model, the stacks and the hook stores.
+    """
+    pretrained = nn.Module()
+    pretrained.model = model
+
+    # Outputs of the four selected blocks are stored under keys "1".."4".
+    for slot in range(4):
+        pretrained.model.blocks[hooks[slot]].register_forward_hook(
+            get_activation(str(slot + 1))
+        )
+
+    pretrained.activations = activations
+
+    if enable_attention_hooks:
+        for slot in range(4):
+            pretrained.model.blocks[hooks[slot]].attn.register_forward_hook(
+                get_attention(f"attn_{slot + 1}")
+            )
+        pretrained.attention = attention
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    def stem(slot):
+        # Shared prefix of every stack: readout module, swap dims 1 and 2,
+        # unflatten dim 2 into the size // 16 grid, 1x1 projection.
+        return [
+            readout_oper[slot],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[slot],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+        ]
+
+    def upsample(slot, factor):
+        # Transposed convolution whose kernel size and stride both equal `factor`.
+        return nn.ConvTranspose2d(
+            in_channels=features[slot],
+            out_channels=features[slot],
+            kernel_size=factor,
+            stride=factor,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        )
+
+    # Stacks 1 and 2 upsample (x4, x2), stack 3 only projects, stack 4 downsamples (x1/2).
+    pretrained.act_postprocess1 = nn.Sequential(*stem(0), upsample(0, 4))
+    pretrained.act_postprocess2 = nn.Sequential(*stem(1), upsample(1, 2))
+    pretrained.act_postprocess3 = nn.Sequential(*stem(2))
+    pretrained.act_postprocess4 = nn.Sequential(
+        *stem(3),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject these functions into the VisionTransformer instance so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_vit_b_rn50_backbone(
+    model,
+    features=[256, 512, 768, 768],
+    size=[384, 384],
+    hooks=[0, 1, 8, 11],
+    vit_features=768,
+    use_vit_only=False,
+    use_readout="ignore",
+    start_index=1,
+    enable_attention_hooks=False,
+    use_layer_scale=False,
+):
+    """Wrap a hybrid ViT ``model`` with activation hooks and four post-processing stacks.
+
+    Args:
+        model: model exposing ``blocks`` and ``patch_embed.backbone.stages``.
+        features (list): output channels of the four post-processing stacks.
+        size (list): reference size; the stacks unflatten to ``size // 16``.
+        hooks (list): block indices; entries 0 and 1 are only used when
+            ``use_vit_only`` is True, entries 2 and 3 always.
+        vit_features (int): input channels of the 1x1 projections.
+        use_vit_only (bool): take activations "1" and "2" from transformer
+            blocks instead of the first two backbone stages.
+        use_readout (str): forwarded to ``get_readout_oper``.
+        start_index (int): forwarded to ``get_readout_oper`` and set on the model.
+        enable_attention_hooks (bool): also hook ``attn`` submodules.
+        use_layer_scale (bool): replace ``ls1``/``ls2`` of every block with a
+            new timm ``LayerScale(vit_features)``.
+
+    Returns:
+        nn.Module: container holding the model, the stacks and the hook stores.
+    """
+    pretrained = nn.Module()
+
+    if use_layer_scale:
+        from timm.models.vision_transformer import LayerScale
+
+        for block in model.blocks:
+            block.ls1 = LayerScale(vit_features)
+            block.ls2 = LayerScale(vit_features)
+
+    pretrained.model = model
+    vit_only = use_vit_only == True
+    blocks = pretrained.model.blocks
+
+    # Activations "1" and "2" come from transformer blocks or from the backbone stages.
+    if vit_only:
+        early_taps = (blocks[hooks[0]], blocks[hooks[1]])
+    else:
+        stages = pretrained.model.patch_embed.backbone.stages
+        early_taps = (stages[0], stages[1])
+
+    taps = early_taps + (blocks[hooks[2]], blocks[hooks[3]])
+    for slot, tapped in enumerate(taps, start=1):
+        tapped.register_forward_hook(get_activation(str(slot)))
+
+    if enable_attention_hooks:
+        # NOTE(review): fixed block indices are used here rather than `hooks` —
+        # confirm this is intended.
+        for slot, block_idx in enumerate((2, 5, 8, 11), start=1):
+            blocks[block_idx].attn.register_forward_hook(
+                get_attention(f"attn_{slot}")
+            )
+        pretrained.attention = attention
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    def stem(slot):
+        # Shared prefix: readout module, swap dims 1 and 2, unflatten dim 2
+        # into the size // 16 grid, 1x1 projection.
+        return [
+            readout_oper[slot],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[slot],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+        ]
+
+    def upsample(slot, factor):
+        # Transposed convolution whose kernel size and stride both equal `factor`.
+        return nn.ConvTranspose2d(
+            in_channels=features[slot],
+            out_channels=features[slot],
+            kernel_size=factor,
+            stride=factor,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        )
+
+    if vit_only:
+        pretrained.act_postprocess1 = nn.Sequential(*stem(0), upsample(0, 4))
+        pretrained.act_postprocess2 = nn.Sequential(*stem(1), upsample(1, 2))
+    else:
+        # Three-entry identity stacks keep the same length as the stem, so
+        # slicing by position works the same way for every stack.
+        pretrained.act_postprocess1 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+        pretrained.act_postprocess2 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+
+    pretrained.act_postprocess3 = nn.Sequential(*stem(2))
+    pretrained.act_postprocess4 = nn.Sequential(
+        *stem(3),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject these functions into the VisionTransformer instance so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+    pretrained,
+    use_readout="ignore",
+    hooks=None,
+    use_vit_only=False,
+    enable_attention_hooks=False,
+    use_layer_scale=False,
+):
+    """Create timm's ``vit_base_resnet50_384`` and wrap it as a hooked backbone.
+
+    Args:
+        pretrained (bool): forwarded to ``timm.create_model``.
+        use_readout (str): readout handling, forwarded to the backbone builder.
+        hooks (list, optional): block indices to hook. Defaults to [0, 1, 8, 11].
+        use_vit_only (bool): forwarded to the backbone builder.
+        enable_attention_hooks (bool): forwarded to the backbone builder.
+        use_layer_scale (bool): forwarded to the backbone builder.
+
+    Returns:
+        The module returned by ``_make_vit_b_rn50_backbone``.
+    """
+    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+    # Identity test against None: `==` would call the argument's own __eq__.
+    hooks = [0, 1, 8, 11] if hooks is None else hooks
+    return _make_vit_b_rn50_backbone(
+        model,
+        features=[256, 512, 768, 768],
+        size=[384, 384],
+        hooks=hooks,
+        use_vit_only=use_vit_only,
+        use_readout=use_readout,
+        enable_attention_hooks=enable_attention_hooks,
+        use_layer_scale=use_layer_scale,
+    )
+
+
+def _make_pretrained_vitl16_384(
+    pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
+):
+    """Create timm's ``vit_large_patch16_384`` and wrap it as a hooked backbone.
+
+    Args:
+        pretrained (bool): forwarded to ``timm.create_model``.
+        use_readout (str): readout handling, forwarded to the backbone builder.
+        hooks (list, optional): block indices to hook. Defaults to [5, 11, 17, 23].
+        enable_attention_hooks (bool): forwarded to the backbone builder.
+
+    Returns:
+        The module returned by ``_make_vit_b16_backbone``.
+    """
+    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+    # Identity test against None: `==` would call the argument's own __eq__.
+    hooks = [5, 11, 17, 23] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[256, 512, 1024, 1024],
+        hooks=hooks,
+        vit_features=1024,
+        use_readout=use_readout,
+        enable_attention_hooks=enable_attention_hooks,
+    )
+
+
+def _make_pretrained_vitb16_384(
+    pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
+):
+    """Create timm's ``vit_base_patch16_384`` and wrap it as a hooked backbone.
+
+    Args:
+        pretrained (bool): forwarded to ``timm.create_model``.
+        use_readout (str): readout handling, forwarded to the backbone builder.
+        hooks (list, optional): block indices to hook. Defaults to [2, 5, 8, 11].
+        enable_attention_hooks (bool): forwarded to the backbone builder.
+
+    Returns:
+        The module returned by ``_make_vit_b16_backbone``.
+    """
+    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+    # Identity test against None: `==` would call the argument's own __eq__.
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[96, 192, 384, 768],
+        hooks=hooks,
+        use_readout=use_readout,
+        enable_attention_hooks=enable_attention_hooks,
+    )
+
+
+def _make_pretrained_deitb16_384(
+    pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
+):
+    """Create timm's ``vit_deit_base_patch16_384`` and wrap it as a hooked backbone.
+
+    Args:
+        pretrained (bool): forwarded to ``timm.create_model``.
+        use_readout (str): readout handling, forwarded to the backbone builder.
+        hooks (list, optional): block indices to hook. Defaults to [2, 5, 8, 11].
+        enable_attention_hooks (bool): forwarded to the backbone builder.
+
+    Returns:
+        The module returned by ``_make_vit_b16_backbone``.
+    """
+    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+    # Identity test against None: `==` would call the argument's own __eq__.
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[96, 192, 384, 768],
+        hooks=hooks,
+        use_readout=use_readout,
+        enable_attention_hooks=enable_attention_hooks,
+    )
+
+
+def _make_pretrained_deitb16_distil_384(
+    pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
+):
+    """Create timm's ``vit_deit_base_distilled_patch16_384`` and wrap it as a hooked backbone.
+
+    The backbone is built with ``start_index=2``.
+
+    Args:
+        pretrained (bool): forwarded to ``timm.create_model``.
+        use_readout (str): readout handling, forwarded to the backbone builder.
+        hooks (list, optional): block indices to hook. Defaults to [2, 5, 8, 11].
+        enable_attention_hooks (bool): forwarded to the backbone builder.
+
+    Returns:
+        The module returned by ``_make_vit_b16_backbone``.
+    """
+    model = timm.create_model(
+        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+    )
+
+    # Identity test against None: `==` would call the argument's own __eq__.
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[96, 192, 384, 768],
+        hooks=hooks,
+        use_readout=use_readout,
+        start_index=2,
+        enable_attention_hooks=enable_attention_hooks,
+    )
diff --git a/scalelsd/ssl/backbones/multi_task_head.py b/scalelsd/ssl/backbones/multi_task_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..f40a0f15b5cc72e3f140545b3482970b863734fd
--- /dev/null
+++ b/scalelsd/ssl/backbones/multi_task_head.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+class MultitaskHead(nn.Module):
+    """Parallel conv heads whose outputs are concatenated along the channel dim."""
+
+    def __init__(self, input_channels, num_class, head_size):
+        """Build one 3x3 conv -> ReLU -> 1x1 conv branch per entry of the flattened ``head_size``.
+
+        Args:
+            input_channels (int): channels of the input feature map.
+            num_class (int): expected total of output channels over all heads.
+            head_size (list[list[int]]): output channels per head, grouped in
+                sub-lists that are flattened here.
+
+        Raises:
+            AssertionError: if ``num_class`` differs from the sum of ``head_size``.
+        """
+        super().__init__()
+
+        hidden = int(input_channels / 4)
+        branch_channels = sum(head_size, [])
+
+        self.heads = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(input_channels, hidden, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(hidden, out_channels, kernel_size=1),
+                )
+                for out_channels in branch_channels
+            ]
+        )
+        assert num_class == sum(sum(head_size, []))
+
+    def forward(self, x):
+        """Apply every head to ``x`` and concatenate the results on dim 1."""
+        outputs = [head(x) for head in self.heads]
+        return torch.cat(outputs, dim=1)
+
+
+class AngleDistanceHead(nn.Module):
+ """Parallel conv heads concatenated on dim 1; 2-channel heads end in a CosineSineLayer."""
+ def __init__(self, input_channels, num_class, head_size):
+ super(AngleDistanceHead, self).__init__()
+
+ # Hidden width shared by all heads: a quarter of the input channels.
+ m = int(input_channels/4)
+
+ heads = []
+ # head_size is a list of lists; sum(head_size, []) flattens it to one entry per head.
+ for output_channels in sum(head_size, []):
+ if output_channels != 2:
+ heads.append(
+ nn.Sequential(
+ nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(m, output_channels, kernel_size=1),
+ )
+ )
+ else:
+ # NOTE(review): CosineSineLayer is neither defined nor imported in this
+ # module (only torch and torch.nn are imported), so this branch raises
+ # NameError whenever a head has 2 output channels — confirm where it
+ # should come from.
+ heads.append(
+ nn.Sequential(
+ nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ CosineSineLayer(m)
+ )
+ )
+ self.heads = nn.ModuleList(heads)
+ # num_class must equal the total number of output channels over all heads.
+ assert num_class == sum(sum(head_size, []))
+ def forward(self, x):
+ # Apply every head to x and concatenate the results along dim 1.
+ return torch.cat([head(x) for head in self.heads], dim=1)
\ No newline at end of file
diff --git a/scalelsd/ssl/config/__init__.py b/scalelsd/ssl/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d19d721b1406422a9dc5e341c21d31c5984d9
--- /dev/null
+++ b/scalelsd/ssl/config/__init__.py
@@ -0,0 +1,2 @@
+from .project_config import Config
+from .utils import *
\ No newline at end of file
diff --git a/scalelsd/ssl/config/dataset/hpatches_dataset.yaml b/scalelsd/ssl/config/dataset/hpatches_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..70933157ecc10ebbb797d838f17bd067fd885c8c
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/hpatches_dataset.yaml
@@ -0,0 +1,105 @@
+### General dataset parameters
+dataset_name: "hpatches"
+add_augmentation_to_all_splits: False
+gray_scale: True
+# Ground truth source ('official' or path to the exported h5 dataset.)
+# gt_source_train: "" # Fill with your own export file
+# gt_source_test: "" # Fill with your own export file
+# Return type: (1) single (to train the detector only)
+# or (2) paired_desc (to train the detector + descriptor)
+return_type: "single"
+random_seed: 0
+
+### Descriptor training parameters
+# Number of points extracted per line
+max_num_samples: 10
+# Max number of training line points extracted in the whole image
+max_pts: 1000
+# Min distance between two points on a line (in pixels)
+min_dist_pts: 10
+# Small jittering of the sampled points during training
+jittering: 0
+
+alteration: "all"
+max_side: 1200
+
+### Data preprocessing configuration
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ random_scaling:
+ enable: True
+ range: [0.7, 1.5]
+ photometric:
+ enable: true
+ primitives: ['random_brightness', 'random_contrast',
+ 'additive_speckle_noise', 'additive_gaussian_noise',
+ 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: true
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
+
+## Homography adaptation configuration
+homography_adaptation:
+ num_iter: 10
+ valid_border_margin: 3
+ min_counts: 3
+ homographies:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ allow_artifacts: true
+ patch_ratio: 0.85
+
+data:
+ name: hpatches
+ dataset_dir: HPatches_sequences
+ alteration: all
+ max_side: 1200
+ batch_size: 1
+ num_workers: 4
+model:
+ name: deeplsd
+ tiny: False
+ sharpen: True
+ line_neighborhood: 5
+ loss_weights:
+ df: 1.
+ angle: 1.
+ detect_lines: True
+ multiscale: False
+ scale_factors: [1., 1.5]
+ line_detection_params:
+ grad_nfa: True
+ merge: False
+ optimize: False
+ use_vps: False
+ optimize_vps: False
+ filtering: True
+ grad_thresh: 3
diff --git a/scalelsd/ssl/config/dataset/nyu_dataset.yaml b/scalelsd/ssl/config/dataset/nyu_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ceae2b4c7e86f2d85cd770ffcc52d14baf3e1152
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/nyu_dataset.yaml
@@ -0,0 +1,77 @@
+### General dataset parameters
+dataset_name: "nyu"
+add_augmentation_to_all_splits: False
+gray_scale: True
+# Ground truth source ('official' or path to the exported h5 dataset.)
+# gt_source_train: "" # Fill with your own export file
+# gt_source_test: "" # Fill with your own export file
+# Return type: (1) single (to train the detector only)
+# or (2) paired_desc (to train the detector + descriptor)
+return_type: "single"
+random_seed: 0
+
+val_size: 49
+
+### Descriptor training parameters
+# Number of points extracted per line
+max_num_samples: 10
+# Max number of training line points extracted in the whole image
+max_pts: 1000
+# Min distance between two points on a line (in pixels)
+min_dist_pts: 10
+# Small jittering of the sampled points during training
+jittering: 0
+
+### Data preprocessing configuration
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ random_scaling:
+ enable: True
+ range: [0.7, 1.5]
+ photometric:
+ enable: true
+ primitives: ['random_brightness', 'random_contrast',
+ 'additive_speckle_noise', 'additive_gaussian_noise',
+ 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: true
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
+
+## Homography adaptation configuration
+homography_adaptation:
+ num_iter: 10
+ valid_border_margin: 3
+ min_counts: 3
+ homographies:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ allow_artifacts: true
+ patch_ratio: 0.85
diff --git a/scalelsd/ssl/config/dataset/official_yorkurban_dataset.yaml b/scalelsd/ssl/config/dataset/official_yorkurban_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d4cfe27c50ac1f471fadec2806ec62ec40103a1c
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/official_yorkurban_dataset.yaml
@@ -0,0 +1,75 @@
+### General dataset parameters
+dataset_name: "official_yorkurban"
+add_augmentation_to_all_splits: False
+gray_scale: True
+# Ground truth source ('official' or path to the exported h5 dataset.)
+# gt_source_train: "" # Fill with your own export file
+# gt_source_test: "" # Fill with your own export file
+# Return type: (1) single (to train the detector only)
+# or (2) paired_desc (to train the detector + descriptor)
+return_type: "single"
+random_seed: 0
+
+### Descriptor training parameters
+# Number of points extracted per line
+max_num_samples: 10
+# Max number of training line points extracted in the whole image
+max_pts: 1000
+# Min distance between two points on a line (in pixels)
+min_dist_pts: 10
+# Small jittering of the sampled points during training
+jittering: 0
+
+### Data preprocessing configuration
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ random_scaling:
+ enable: True
+ range: [0.7, 1.5]
+ photometric:
+ enable: true
+ primitives: ['random_brightness', 'random_contrast',
+ 'additive_speckle_noise', 'additive_gaussian_noise',
+ 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: true
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
+
+## Homography adaptation configuration
+homography_adaptation:
+ num_iter: 10
+ valid_border_margin: 3
+ min_counts: 3
+ homographies:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ allow_artifacts: true
+ patch_ratio: 0.85
diff --git a/scalelsd/ssl/config/dataset/rdnim_dataset.yaml b/scalelsd/ssl/config/dataset/rdnim_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0524bb0ccfb0d4f130d045c0be9c03a5b8667c8c
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/rdnim_dataset.yaml
@@ -0,0 +1,77 @@
+### General dataset parameters
+dataset_name: "rdnim"
+add_augmentation_to_all_splits: False
+gray_scale: True
+# Ground truth source ('official' or path to the exported h5 dataset.)
+# gt_source_train: "" # Fill with your own export file
+# gt_source_test: "" # Fill with your own export file
+# Return type: (1) single (to train the detector only)
+# or (2) paired_desc (to train the detector + descriptor)
+return_type: "single"
+random_seed: 0
+
+### Descriptor training parameters
+# Number of points extracted per line
+max_num_samples: 10
+# Max number of training line points extracted in the whole image
+max_pts: 1000
+# Min distance between two points on a line (in pixels)
+min_dist_pts: 10
+# Small jittering of the sampled points during training
+jittering: 0
+
+reference: "night"
+
+### Data preprocessing configuration
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ random_scaling:
+ enable: True
+ range: [0.7, 1.5]
+ photometric:
+ enable: true
+ primitives: ['random_brightness', 'random_contrast',
+ 'additive_speckle_noise', 'additive_gaussian_noise',
+ 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: true
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
+
+## Homography adaptation configuration
+homography_adaptation:
+ num_iter: 10
+ valid_border_margin: 3
+ min_counts: 3
+ homographies:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ allow_artifacts: true
+ patch_ratio: 0.85
diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset-1024.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset-1024.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fd79820d57c60bb7356054ad91b51b7b2db57464
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/synthetic_dataset-1024.yaml
@@ -0,0 +1,49 @@
+### General dataset parameters
+dataset_name: "synthetic_shape"
+primitives: "all"
+add_augmentation_to_all_splits: True
+test_augmentation_seed: 200
+# Shape generation configuration
+generation:
+ # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
+ split_sizes: {'train': 2000, 'val': 2000, 'test': 400}
+ random_seed: 10
+ image_size: [960, 1280]
+ min_len: 0.0985
+ min_label_len: 0.099
+ params:
+ generate_background:
+ min_kernel_size: 150
+ max_kernel_size: 500
+ min_rad_ratio: 0.02
+ max_rad_ratio: 0.031
+ draw_stripes:
+ transform_params: [0.1, 0.1]
+ draw_multiple_polygons:
+ kernel_boundaries: [50, 100]
+
+### Data preprocessing configuration.
+preprocessing:
+ resize: [1024, 1024]
+ blur_size: 11
+augmentation:
+ photometric:
+ enable: True
+ primitives: 'all'
+ params: {}
+ random_order: True
+ homographic:
+ enable: True
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.8
+ max_angle: 1.57
+ allow_artifacts: true
+ translation_overflow: 0.05
+ valid_border_margin: 0
diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset-2k.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset-2k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f760ae66077247c9e6c355decf0f277e425561b4
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/synthetic_dataset-2k.yaml
@@ -0,0 +1,50 @@
+### General dataset parameters
+dataset_name: "synthetic_shape"
+primitives: "all"
+add_augmentation_to_all_splits: True
+test_augmentation_seed: 200
+alias: 2k
+# Shape generation configuration
+generation:
+ # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
+ split_sizes: {'train': 2000, 'val': 200, 'test': 400}
+ random_seed: 10
+ image_size: [960, 1280]
+ min_len: 0.0985
+ min_label_len: 0.099
+ params:
+ generate_background:
+ min_kernel_size: 150
+ max_kernel_size: 500
+ min_rad_ratio: 0.02
+ max_rad_ratio: 0.031
+ draw_stripes:
+ transform_params: [0.1, 0.1]
+ draw_multiple_polygons:
+ kernel_boundaries: [50, 100]
+
+### Data preprocessing configuration.
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ photometric:
+ enable: True
+ primitives: 'all'
+ params: {}
+ random_order: True
+ homographic:
+ enable: True
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.8
+ max_angle: 1.57
+ allow_artifacts: true
+ translation_overflow: 0.05
+ valid_border_margin: 0
diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset-4k.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset-4k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7d1dd99cc778b6211d4d3251fb0f12d8e4bf9daf
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/synthetic_dataset-4k.yaml
@@ -0,0 +1,50 @@
+### General dataset parameters
+dataset_name: "synthetic_shape"
+primitives: "all"
+add_augmentation_to_all_splits: True
+test_augmentation_seed: 200
+alias: 4k
+# Shape generation configuration
+generation:
+ # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
+ split_sizes: {'train': 4000, 'val': 2000, 'test': 400}
+ random_seed: 10
+ image_size: [960, 1280]
+ min_len: 0.0985
+ min_label_len: 0.099
+ params:
+ generate_background:
+ min_kernel_size: 150
+ max_kernel_size: 500
+ min_rad_ratio: 0.02
+ max_rad_ratio: 0.031
+ draw_stripes:
+ transform_params: [0.1, 0.1]
+ draw_multiple_polygons:
+ kernel_boundaries: [50, 100]
+
+### Data preprocessing configuration.
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ photometric:
+ enable: True
+ primitives: 'all'
+ params: {}
+ random_order: True
+ homographic:
+ enable: True
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.8
+ max_angle: 1.57
+ allow_artifacts: true
+ translation_overflow: 0.05
+ valid_border_margin: 0
diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset-large.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset-large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..844f44af29b335e0b4275d0feba792a77a522513
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/synthetic_dataset-large.yaml
@@ -0,0 +1,50 @@
+### General dataset parameters
+dataset_name: "synthetic_shape"
+primitives: "all"
+add_augmentation_to_all_splits: True
+test_augmentation_seed: 200
+alias: "synthetic_shape_large"
+# Shape generation configuration
+generation:
+ split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
+ # split_sizes: {'train': 2000, 'val': 2000, 'test': 400}
+ random_seed: 10
+ image_size: [960, 1280]
+ min_len: 0.0985
+ min_label_len: 0.099
+ params:
+ generate_background:
+ min_kernel_size: 150
+ max_kernel_size: 500
+ min_rad_ratio: 0.02
+ max_rad_ratio: 0.031
+ draw_stripes:
+ transform_params: [0.1, 0.1]
+ draw_multiple_polygons:
+ kernel_boundaries: [50, 100]
+
+### Data preprocessing configuration.
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ photometric:
+ enable: True
+ primitives: 'all'
+ params: {}
+ random_order: True
+ homographic:
+ enable: True
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.8
+ max_angle: 1.57
+ allow_artifacts: true
+ translation_overflow: 0.05
+ valid_border_margin: 0
diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a91258f956a7ee68f3cd796005eb1d5c30c9ae83
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/synthetic_dataset.yaml
@@ -0,0 +1,51 @@
+### General dataset parameters
+dataset_name: "synthetic_shape"
+primitives: "all"
+add_augmentation_to_all_splits: True
+test_augmentation_seed: 200
+# Shape generation configuration
+generation:
+ # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
+ # split_sizes: {'train': 2000, 'val': 2000, 'test': 400}
+ split_sizes: {'train': 100, 'val': 100, 'test': 100}
+ random_seed: 10
+ # image_size: [960, 1280]
+ image_size: [1024, 1024]
+ min_len: 0.0985
+ min_label_len: 0.099
+ params:
+ generate_background:
+ min_kernel_size: 150
+ max_kernel_size: 500
+ min_rad_ratio: 0.02
+ max_rad_ratio: 0.031
+ draw_stripes:
+ transform_params: [0.1, 0.1]
+ draw_multiple_polygons:
+ kernel_boundaries: [50, 100]
+
+### Data preprocessing configuration.
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ photometric:
+ enable: True
+ primitives: 'all'
+ params: {}
+ random_order: True
+ homographic:
+ enable: True
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.8
+ max_angle: 1.57
+ allow_artifacts: true
+ translation_overflow: 0.05
+ valid_border_margin: 0
diff --git a/scalelsd/ssl/config/dataset/wireframe_official_gt copy.yaml b/scalelsd/ssl/config/dataset/wireframe_official_gt copy.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f6a7a4194e23f3717bbc3b3a2bf707891f2e9abb
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/wireframe_official_gt copy.yaml
@@ -0,0 +1,86 @@
+dataset_name: "wireframe"
+add_augmentation_to_all_splits: False
+gray_scale: True
+# return_type: "paired_desc"
+random_seed: 0
+# Ground truth source ('official' or path to the exported h5 dataset).
+gt_source_train: "official"
+gt_source_test: "official"
+# Data preprocessing configuration.
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ random_scaling:
+ enable: True
+ range: [0.7, 1.5]
+ photometric:
+ enable: true
+ primitives: ['random_brightness', 'random_contrast',
+ 'additive_speckle_noise', 'additive_gaussian_noise',
+ 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: true
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
+# The homography adaptation configuration
+homography_adaptation:
+ num_iter: 100
+ aggregation: 'sum'
+ mode: 'ver1'
+ valid_border_margin: 3
+ min_counts: 30
+ homographies:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ allow_artifacts: true
+ patch_ratio: 0.85
+# Evaluation related config
+evaluation:
+ repeatability:
+ # Initial random seed used to sample homographic augmentation
+ seed: 200
+ # Parameter used to sample illumination change evaluation set.
+ photometric:
+ enable: False
+ # Parameter used to sample viewpoint change evaluation set.
+ homographic:
+ enable: True
+ num_samples: 2
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
\ No newline at end of file
diff --git a/scalelsd/ssl/config/dataset/wireframe_official_gt.yaml b/scalelsd/ssl/config/dataset/wireframe_official_gt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f6a7a4194e23f3717bbc3b3a2bf707891f2e9abb
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/wireframe_official_gt.yaml
@@ -0,0 +1,86 @@
+dataset_name: "wireframe"
+add_augmentation_to_all_splits: False
+gray_scale: True
+# return_type: "paired_desc"
+random_seed: 0
+# Ground truth source ('official' or path to the exported h5 dataset).
+gt_source_train: "official"
+gt_source_test: "official"
+# Data preprocessing configuration.
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ random_scaling:
+ enable: True
+ range: [0.7, 1.5]
+ photometric:
+ enable: true
+ primitives: ['random_brightness', 'random_contrast',
+ 'additive_speckle_noise', 'additive_gaussian_noise',
+ 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: true
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
+# The homography adaptation configuration
+homography_adaptation:
+ num_iter: 100
+ aggregation: 'sum'
+ mode: 'ver1'
+ valid_border_margin: 3
+ min_counts: 30
+ homographies:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ allow_artifacts: true
+ patch_ratio: 0.85
+# Evaluation related config
+evaluation:
+ repeatability:
+ # Initial random seed used to sample homographic augmentation
+ seed: 200
+ # Parameter used to sample illumination change evaluation set.
+ photometric:
+ enable: False
+ # Parameter used to sample viewpoint change evaluation set.
+ homographic:
+ enable: True
+ num_samples: 2
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
\ No newline at end of file
diff --git a/scalelsd/ssl/config/dataset/yorkurban_dataset.yaml b/scalelsd/ssl/config/dataset/yorkurban_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..af75ec177e59492368486a761bd4f34427e15e20
--- /dev/null
+++ b/scalelsd/ssl/config/dataset/yorkurban_dataset.yaml
@@ -0,0 +1,99 @@
+### General dataset parameters
+dataset_name: "yorkurban"
+add_augmentation_to_all_splits: False
+gray_scale: True
+# Ground truth source ('official' or path to the exported h5 dataset.)
+# gt_source_train: "" # Fill with your own export file
+# gt_source_test: "" # Fill with your own export file
+# Return type: (1) single (to train the detector only)
+# or (2) paired_desc (to train the detector + descriptor)
+return_type: "single"
+random_seed: 0
+
+### Descriptor training parameters
+# Number of points extracted per line
+max_num_samples: 10
+# Max number of training line points extracted in the whole image
+max_pts: 1000
+# Min distance between two points on a line (in pixels)
+min_dist_pts: 10
+# Small jittering of the sampled points during training
+jittering: 0
+
+### Data preprocessing configuration
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ random_scaling:
+ enable: True
+ range: [0.7, 1.5]
+ photometric:
+ enable: true
+ primitives: ['random_brightness', 'random_contrast',
+ 'additive_speckle_noise', 'additive_gaussian_noise',
+ 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: true
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
+
+## Homography adaptation configuration
+homography_adaptation:
+ num_iter: 10
+ valid_border_margin: 3
+ min_counts: 3
+ homographies:
+ translation: false
+ rotation: false
+ scaling: true
+ perspective: false
+ scaling_amplitude: -1
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ allow_artifacts: true
+ patch_ratio: 0.85
+# Evaluation related config
+evaluation:
+ repeatability:
+ # Initial random seed used to sample homographic augmentation
+ seed: 200
+ # Parameter used to sample illumination change evaluation set.
+ photometric:
+ enable: False
+ # Parameter used to sample viewpoint change evaluation set.
+ homographic:
+ enable: True
+ num_samples: 2
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: -1
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
diff --git a/scalelsd/ssl/config/project_config.py b/scalelsd/ssl/config/project_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1070824a584e38e52941aca2604fbc80cf7fca2f
--- /dev/null
+++ b/scalelsd/ssl/config/project_config.py
@@ -0,0 +1,70 @@
+"""
+Project configurations.
+"""
+import os
+
+
+class Config(object):
+ """ Datasets and experiments folders for the whole project. """
+ #####################
+ ## Dataset setting ##
+ #####################
+ default_dataroot = os.path.join(
+ os.path.dirname(__file__),
+ '..','..','..','data-ssl'
+ )
+ default_dataroot = os.path.abspath(default_dataroot)
+ default_exproot = os.path.join(
+ os.path.dirname(__file__),
+ '..','..','..','exp-ssl'
+ )
+ default_exproot = os.path.abspath(default_exproot)
+
+ DATASET_ROOT = os.getenv("DATASET_ROOT", default_dataroot) # TODO: path to your datasets folder
+ if not os.path.exists(DATASET_ROOT):
+ os.makedirs(DATASET_ROOT, exist_ok=True)
+
+ # Synthetic shape dataset
+ synthetic_dataroot = os.path.join(DATASET_ROOT, "synthetic_shapes")
+ synthetic_cache_path = os.path.join(DATASET_ROOT, "synthetic_shapes")
+ if not os.path.exists(synthetic_dataroot):
+ os.makedirs(synthetic_dataroot, exist_ok=True)
+
+ EXPORT_ROOT = os.getenv("EXPORT_ROOT", default_dataroot) # TODO: path to your datasets folder
+
+ # Exported predictions dataset
+ export_dataroot = os.path.join(EXPORT_ROOT, "export_datasets")
+ export_cache_path = os.path.join(EXPORT_ROOT, "export_datasets")
+ if not os.path.exists(export_dataroot):
+ os.makedirs(export_dataroot, exist_ok=True)
+
+ # York Urban dataset
+ yorkurban_dataroot = os.path.join(DATASET_ROOT, "YorkUrbanDB")
+ yorkurban_cache_path = os.path.join(DATASET_ROOT, "YorkUrbanDB")
+
+ # Wireframe dataset
+ wireframe_dataroot = os.path.join(DATASET_ROOT, "wireframe")
+ wireframe_cache_path = os.path.join(DATASET_ROOT, "wireframe")
+
+ # Holicity dataset
+ holicity_dataroot = os.path.join(DATASET_ROOT, "Holicity")
+ holicity_cache_path = os.path.join(DATASET_ROOT, "Holicity")
+
+ # Official York Urban dataset
+ official_yorkurban_dataroot = os.path.join(DATASET_ROOT, "off_YorkUrbanDB")
+ official_yorkurban_cache_path = os.path.join(DATASET_ROOT, "off_YorkUrbanDB")
+
+ # NYU_depth_v2
+ nyu_dataroot = os.path.join(DATASET_ROOT, "NYU_depth_v2")
+ nyu_dataroot_cache_path = os.path.join(DATASET_ROOT, "NYU_depth_v2")
+
+ rdnim_dataroot = os.path.join(DATASET_ROOT, "RDNIM")
+ hpatches_dataroot = os.path.join(DATASET_ROOT, "HPatches_sequences")
+
+ ########################
+ ## Experiment Setting ##
+ ########################
+ EXP_PATH = os.getenv("EXP_PATH", default_exproot) # TODO: path to your experiments folder
+
+ if not os.path.exists(EXP_PATH):
+ os.makedirs(EXP_PATH, exist_ok=True)
diff --git a/scalelsd/ssl/config/utils.py b/scalelsd/ssl/config/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e5fba5079c0c44a2498127c6bc412d63f08d724
--- /dev/null
+++ b/scalelsd/ssl/config/utils.py
@@ -0,0 +1,50 @@
+import yaml
+import os
+from easydict import EasyDict
+
+def load_config(config_path):
+ """ Load configurations from a given yaml file. """
+ # Check file exists
+ if not os.path.exists(config_path):
+ raise ValueError("[Error] The provided config path is not valid.")
+
+ # Load the configuration
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+
+ # return EasyDict(config)
+ return config
+
+def update_config(path, model_cfg=None, dataset_cfg=None):
+ """ Update configuration file from the resume path. """
+ # Check we need to update or completely override.
+ model_cfg = {} if model_cfg is None else model_cfg
+ dataset_cfg = {} if dataset_cfg is None else dataset_cfg
+
+ # Load saved configs
+ with open(os.path.join(path, "model_cfg.yaml"), "r") as f:
+ model_cfg_saved = yaml.safe_load(f)
+ model_cfg.update(model_cfg_saved)
+ with open(os.path.join(path, "dataset_cfg.yaml"), "r") as f:
+ dataset_cfg_saved = yaml.safe_load(f)
+ dataset_cfg.update(dataset_cfg_saved)
+
+ # Update the saved yaml file
+ if not model_cfg == model_cfg_saved:
+ with open(os.path.join(path, "model_cfg.yaml"), "w") as f:
+ yaml.dump(model_cfg, f)
+ if not dataset_cfg == dataset_cfg_saved:
+ with open(os.path.join(path, "dataset_cfg.yaml"), "w") as f:
+ yaml.dump(dataset_cfg, f)
+
+ return model_cfg, dataset_cfg
+
+def record_config(model_cfg, dataset_cfg, output_path):
+ """ Record dataset config to the log path. """
+ # Record model config
+ with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f:
+ yaml.safe_dump(model_cfg, f)
+
+ # Record dataset config
+ with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f:
+ yaml.safe_dump(dataset_cfg, f)
\ No newline at end of file
diff --git a/scalelsd/ssl/datasets/__init__.py b/scalelsd/ssl/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scalelsd/ssl/datasets/dataset_eval.py b/scalelsd/ssl/datasets/dataset_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dfd2fb9940ae6480472007eb2d11a44e555582a
--- /dev/null
+++ b/scalelsd/ssl/datasets/dataset_eval.py
@@ -0,0 +1,87 @@
+from pathlib import Path
+import cv2
+import PIL
+import numpy as np
+import torch
+import torch.utils
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+import glob
+
+from .transforms.homographic_transforms import sample_homography
+from kornia.geometry import warp_perspective,transform_points
+
+# Keyword arguments passed to sample_homography() by Hybrid_Dataset below.
+homography_params = {
+    'translation': True,
+    'rotation': True,
+    'scaling': True,
+    'perspective': True,
+    'scaling_amplitude': 0.2,
+    'perspective_amplitude_x': 0.2,
+    'perspective_amplitude_y': 0.2,
+    'patch_ratio': 0.85,
+    'max_angle': 1.57,
+    'allow_artifacts': True
+}
+
+class Hybrid_Dataset(torch.utils.data.Dataset):
+ def __init__(self, datacfg=None, images_root=None, overwrite=False):
+ self.conf = datacfg
+ self.root = images_root
+
+ # torch.manual_seed(self.conf.seed)
+ # np.random.seed(self.conf.seed)
+
+ # # Extract images paths
+ # self.files = [Path(self.root)/img for img in Path(self.root).iterdir()
+ # if img.with_suffix('.png') or img.with_suffix('.jpg')]
+ self.files = glob.glob(f'{images_root}/*.png') + glob.glob(f'{images_root}/*.jpg')
+ self.files.sort()
+
+ self.npz_files = [] if overwrite else glob.glob(f'{images_root}/*.npz')
+
+ self.size = (512, 512)
+
+ self.overwrite = overwrite
+
+ if len(self.files) == 0:
+ raise ValueError(f'Could not find any images in the path of {self.root}. Please check the input images root path.')
+
+ # Randomly generate the homography for each image to ensure reproducibility
+ for file in tqdm(self.files):
+ npz_file = Path(file).with_suffix('.npz')
+ if not npz_file.exists() or self.overwrite:
+ image = cv2.imread(file, 0)
+ image = cv2.resize(image, self.size)
+ image = np.array(image, dtype=np.float32)/255.0
+
+ w, h = image.shape[:2]
+ H = sample_homography(self.size, **homography_params)[0]
+ warped_image = cv2.warpPerspective(image, H, self.size)
+ warped_image = np.array(warped_image, dtype=np.float32)
+
+ data = {
+ 'ref_image': image,
+ 'target_image': warped_image,
+ 'homo_mat': H,
+ }
+
+ np.savez(npz_file, ref_image=image, target_image=warped_image, homo_mat=H)
+
+ self.npz_files.append(npz_file)
+
+ def get_dataset(self):
+ return self.npz_files
+
+ def get_images(self):
+ return self.files
+
+ def len_dataset(self):
+ return len(self.files)
+
+ def __getitem__(self, idx):
+ npz_file = self.npz_files(idx)
+ data = np.load(npz_file)
+
+ return data
diff --git a/scalelsd/ssl/datasets/dataset_util.py b/scalelsd/ssl/datasets/dataset_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..c24642a94b444d2eb1d898de4e240fba54724f5f
--- /dev/null
+++ b/scalelsd/ssl/datasets/dataset_util.py
@@ -0,0 +1,94 @@
+"""
+The interface of initializing different datasets.
+"""
+from .synthetic_dataset import SyntheticShapes,synthetic_collate_fn
+from .wireframe_dataset import WireframeDataset,wireframe_collate_fn
+from .yorkurban_dataset import YorkUrbanDataset,yorkurban_collate_fn
+from .images_dataset import ImageCollections, images_collate_fn
+# from .holicity_dataset import HolicityDataset
+# from .merge_dataset import MergeDataset
+import torch.utils.data.dataloader as torch_loader
+try:
+ from .official_yorkurban_dataset import YorkUrban
+except:
+ pass
+
+from .nyu_dataset import NYU
+from .rdnim_dataset import RDNIM
+from .hpatches_dataset import HPatches
+
+def get_dataset(mode="train", dataset_cfg=None, homoadp=False, **kwargs):
+ """ Initialize different dataset based on a configuration. """
+ # Check dataset config is given
+ if dataset_cfg is None:
+ raise ValueError("[Error] The dataset config is required!")
+
+ # Synthetic dataset
+ if dataset_cfg["dataset_name"] == "synthetic_shape":
+ dataset = SyntheticShapes(
+ mode, dataset_cfg
+ )
+ # Get the collate_fn
+ # from sold2.dataset.synthetic_dataset import synthetic_collate_fn
+ collate_fn = synthetic_collate_fn
+
+ # Wireframe dataset
+ elif dataset_cfg["dataset_name"] == "wireframe":
+ dataset = WireframeDataset(
+ mode, dataset_cfg
+ )
+
+ # Get the collate_fn
+ collate_fn = wireframe_collate_fn
+ elif dataset_cfg["dataset_name"] == "yorkurban":
+ dataset = YorkUrbanDataset(
+ mode, dataset_cfg
+ )
+
+ # Get the collate_fn
+ collate_fn = yorkurban_collate_fn
+ # Holicity dataset
+ elif dataset_cfg["dataset_name"] == "holicity":
+ dataset = HolicityDataset(
+ mode, dataset_cfg
+ )
+
+ # Get the collate_fn
+ from sold2.dataset.holicity_dataset import holicity_collate_fn
+ collate_fn = holicity_collate_fn
+
+ # Dataset merging several datasets in one
+ elif dataset_cfg["dataset_name"] == "merge":
+ dataset = MergeDataset(
+ mode, dataset_cfg
+ )
+
+ # Get the collate_fn
+ from sold2.dataset.holicity_dataset import holicity_collate_fn
+ collate_fn = holicity_collate_fn
+ elif dataset_cfg["dataset_name"] == "general":
+ dataset = ImageCollections(mode, dataset_cfg, homoadp=homoadp,**kwargs)
+ collate_fn = images_collate_fn
+
+
+ ## for the official YorkUrbanDB
+ elif dataset_cfg["dataset_name"] == "official_yorkurban":
+ dataset = YorkUrban(mode, dataset_cfg)
+ collate_fn = torch_loader.default_collate
+
+ ## for the NYU_depth_v2
+ elif dataset_cfg["dataset_name"] == "nyu":
+ dataset = NYU(mode, dataset_cfg)
+ collate_fn = torch_loader.default_collate
+
+ elif dataset_cfg["dataset_name"] == "rdnim":
+ dataset = RDNIM(dataset_cfg)
+ collate_fn = torch_loader.default_collate
+ elif dataset_cfg["dataset_name"] == "hpatches":
+ dataset = HPatches(mode, dataset_cfg)
+ collate_fn = torch_loader.default_collate
+ else:
+ raise ValueError(
+ "[Error] The dataset '%s' is not supported" % dataset_cfg["dataset_name"])
+
+ return dataset, collate_fn
diff --git a/scalelsd/ssl/datasets/hpatches_dataset.py b/scalelsd/ssl/datasets/hpatches_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b02ce44c38a406837516685ada9ad28176d0ec1e
--- /dev/null
+++ b/scalelsd/ssl/datasets/hpatches_dataset.py
@@ -0,0 +1,126 @@
+"""
+HPatches sequences dataset, to perform homography estimation and
+evaluate basic line detection metrics.
+"""
+import os
+import numpy as np
+import torch
+import cv2
+from pathlib import Path
+from torch.utils.data import Dataset, DataLoader
+
+from ..config.project_config import Config as cfg
+
+
+
+class HPatches(torch.utils.data.Dataset):
+    def __init__(self, mode='test', config=None):
+        """ Index the HPatches sequences found under cfg.hpatches_dataroot.
+
+        'test' mode pairs image 1 of every sequence with images 2..6 and
+        loads the homography file H_1_<i>; 'export' mode lists images 1..6
+        individually. config['alteration'] of 'i' or 'v' keeps only the
+        sequences whose folder name starts with that letter.
+        """
+        assert mode in ['test', 'export']
+        self.conf = config
+        self.root_dir = Path(cfg.hpatches_dataroot)
+        alteration = (config or {}).get('alteration')
+        # iterdir() order is arbitrary; sort so sample indices are reproducible
+        folder_paths = sorted(x for x in self.root_dir.iterdir() if x.is_dir())
+        self.data = []
+        for path in folder_paths:
+            if alteration in ('i', 'v') and path.stem[0] != alteration:
+                continue
+            if mode != 'test':
+                self.data += [{"ref_name": path.stem + "_" + str(i),
+                               "ref_img_path": str(Path(path, str(i) + '.ppm'))}
+                              for i in range(1, 7)]
+                continue
+            self.data += [{"ref_name": path.stem + "_1",
+                           "ref_img_path": str(Path(path, "1.ppm")),
+                           "target_name": path.stem + "_" + str(i),
+                           "target_img_path": str(Path(path, str(i) + '.ppm')),
+                           "H": np.loadtxt(str(Path(path, "H_1_" + str(i))))}
+                          for i in range(2, 7)]
+
+    def get_dataset(self):  # the object is its own dataset
+        return self
+
+    def __getitem__(self, idx):
+        """ Load sample `idx` as grayscale float tensors scaled to [0, 1].
+
+        Returns a dict with 'image' (1 x H x W), 'image_path' and 'name'.
+        Entries built in 'test' mode additionally provide 'warped_image',
+        'warped_image_path', 'warped_name' and the homography 'H'. When the
+        longest side of the reference exceeds conf['max_side'], both images
+        are resized to the same reduced size and H is adapted to it.
+
+        Raises:
+            OSError: if an image cannot be read (cv2.imread returned None).
+        """
+        entry = self.data[idx]
+        img0_path = entry['ref_img_path']
+        img0 = cv2.imread(img0_path, 0)
+        if img0 is None:
+            raise OSError("[Error] Unable to read image '%s'" % img0_path)
+        img_size = img0.shape
+        needs_resize = max(img_size) > self.conf['max_side']
+        if needs_resize:
+            s = self.conf['max_side'] / max(img_size)
+            h_s = int(img_size[0] * s)
+            w_s = int(img_size[1] * s)
+            img0 = cv2.resize(img0, (w_s, h_s), interpolation=cv2.INTER_AREA)
+
+        # Normalize the image in [0, 1]
+        img0 = torch.tensor(img0[None].astype(float) / 255., dtype=torch.float32)
+        outputs = {'image': img0, 'image_path': img0_path,
+                   'name': entry['ref_name']}
+        if 'target_name' not in entry:
+            return outputs
+
+        img1_path = entry['target_img_path']
+        img1 = cv2.imread(img1_path, 0)
+        if img1 is None:
+            raise OSError("[Error] Unable to read image '%s'" % img1_path)
+        H = entry['H']
+        if needs_resize:
+            img1 = cv2.resize(img1, (w_s, h_s), interpolation=cv2.INTER_AREA)
+            H = self.adapt_homography_to_preprocessing(
+                H, img_size, img_size, (h_s, w_s))
+        outputs['warped_image'] = torch.tensor(
+            img1[None].astype(float) / 255., dtype=torch.float)
+        outputs['warped_image_path'] = img1_path
+        outputs['warped_name'] = entry['target_name']
+        outputs['H'] = torch.tensor(H, dtype=torch.float)
+        return outputs
+
+    def __len__(self):  # number of entries collected in self.data by __init__
+        return len(self.data)
+
+    def adapt_homography_to_preprocessing(self, H, img_shape1, img_shape2,
+                                          target_size):
+        """ Re-express H for images resized (and center-cropped) to target_size.
+
+        Shapes and target_size are (height, width); each image is scaled by
+        the larger of the two per-axis ratios and the overflow is split
+        evenly between both sides.
+        """
+        target = np.array(target_size)
+
+        def scale_and_pad(shape):
+            # Resize factor and the (x, y) half-overflow left after that resize
+            size = np.array(shape, dtype=float)
+            scale = np.amax(target / size)
+            pad_x = (size[1] * scale - target[1]) / 2.
+            pad_y = (size[0] * scale - target[0]) / 2.
+            return scale, pad_x, pad_y
+
+        scale1, pad_x1, pad_y1 = scale_and_pad(img_shape1)
+        scale2, pad_x2, pad_y2 = scale_and_pad(img_shape2)
+        unscale1 = np.diag([1. / scale1, 1. / scale1, 1.])
+        uncrop1 = np.array([[1., 0., pad_x1], [0., 1., pad_y1], [0., 0., 1.]])
+        rescale2 = np.diag([scale2, scale2, 1.])
+        crop2 = np.array([[1., 0., -pad_x2], [0., 1., -pad_y2], [0., 0., 1.]])
+        return crop2 @ rescale2 @ H @ unscale1 @ uncrop1
+
diff --git a/scalelsd/ssl/datasets/images_dataset.py b/scalelsd/ssl/datasets/images_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b54f196acd7f1c45840a2ecef45902a761fb36
--- /dev/null
+++ b/scalelsd/ssl/datasets/images_dataset.py
@@ -0,0 +1,400 @@
+try:
+ from pcache_fileio import fileio
+except:
+ pass
+
+import os
+import os.path as osp
+import glob
+import math
+import copy
+from skimage.io import imread
+from skimage import color
+import PIL
+from PIL import Image
+import numpy as np
+import h5py
+import cv2
+import pickle
+import torch
+import torch.utils.data.dataloader as torch_loader
+from torch.utils.data import Dataset
+from torchvision import transforms
+from pathlib import Path
+import json
+
+from ..config.project_config import Config as cfg
+from .transforms import photometric_transforms as photoaug
+from .transforms import homographic_transforms as homoaug
+from .transforms.utils import random_scaling
+from .synthetic_util import get_line_heatmap
+from ..misc.train_utils import parse_h5_data
+from ..misc.geometry_utils import warp_points, mask_points
+from tqdm import tqdm
+
+def images_collate_fn(batch):
+    """ Collate image samples: stack some keys, keep others as plain lists.
+
+    Keys are taken from the first sample; keys in neither group are dropped.
+    """
+    stacked_keys = ("image", "junction_map", "valid_mask", "heatmap",
+                    "heatmap_pos", "heatmap_neg", "homography",
+                    "line_points", "line_indices")
+    listed_keys = ("junctions", "line_map", "line_map_pos", "line_map_neg",
+                   "file_key", "fname", "image_origin", "uuid")
+
+    collated = {}
+    for name in batch[0]:
+        to_stack = name in stacked_keys
+        to_list = name in listed_keys
+        if to_stack and to_list:
+            raise ValueError(
+                "[Error] A key matches batch keys and list keys simultaneously.")
+        if to_stack:
+            collated[name] = torch_loader.default_collate(
+                [sample[name] for sample in batch])
+        elif to_list:
+            collated[name] = [sample[name] for sample in batch]
+
+    return collated
+
+class ImageCollections(Dataset):
+    def __init__(self, mode, config, homoadp=False,homoadp_resume=False):
+        """ Collect the image files and, when configured, their exported labels.
+
+        config['gt_source_train'] (resolved under cfg.EXPORT_ROOT/export_datasets)
+        is either a '.h5' file that also lists the image filenames, a '.jsons'
+        folder with one JSON per image, or absent (images only, no labels).
+        With a '.jsons' source, homoadp=True lists images to label instead of
+        labelled pairs, and homoadp_resume skips images that already have a JSON.
+        `mode` is not used here.
+        """
+        super(ImageCollections, self).__init__()
+        self.config = self.get_default_config() if config is None else config
+        # Read from self.config: `config` itself may be None here
+        h5path = self.config.get('gt_source_train',None)
+        self.json_list = None
+        self.use_json = False
+        self.homoadp = homoadp
+        self.homoadp_resume = homoadp_resume
+        # Set for every label source because __getitem__ derives 'uuid' from it
+        self.dataset_root = osp.join(cfg.DATASET_ROOT,self.config['dataset_root'][0])
+
+        if self.config['img_reg_exp'] == 'all':
+            self.config['img_reg_exp'] = [
+                f'sa_{i:06d}/images/*.jpg' for i in range(998)]
+
+        if h5path is not None and h5path.endswith('.h5'):
+            self.h5path = osp.join(cfg.EXPORT_ROOT,'export_datasets',h5path)
+            with h5py.File(self.h5path,'r') as f:
+                self.filenames = [osp.join(cfg.EXPORT_ROOT,k.decode('UTF-8'))
+                                  for k in f['filenames']]
+        elif h5path is not None and h5path.endswith('.jsons'):
+            self.use_json = True
+            self.h5path = osp.join(cfg.EXPORT_ROOT,'export_datasets',h5path)
+            json_list = []
+            for exp in tqdm(self.config['img_reg_exp']):
+                _json_regexp = Path(exp).with_suffix('.json')
+                json_list.extend(glob.glob(osp.join(self.h5path,str(_json_regexp))))
+
+            pcache_file = Path(h5path).with_suffix('.pcache')
+            if cfg.DATASET_ROOT.startswith('pcache'):
+                if osp.isfile(pcache_file) and (self.homoadp_resume or not self.homoadp):
+                    # Reuse the file listing written by a previous run
+                    with open(pcache_file,'r') as _f:
+                        filenames = [x.rstrip('\n') for x in _f.readlines()]
+                else:
+                    self.folder_regexp = []
+                    filenames = []
+                    print('Loading from pcache......')
+                    for exp in tqdm(self.config['img_reg_exp']):
+                        _path = Path(osp.join(self.config['dataset_root'][0],exp))
+                        _p = osp.join(cfg.DATASET_ROOT,str(_path.parent))
+                        _e = _path.suffix
+                        filenames.extend(
+                            osp.join(_p,osp.basename(_)) for _ in os.listdir(_p) if _.endswith(_e))
+
+                    with open(pcache_file,'w') as _f:
+                        _f.writelines('\n'.join(filenames))
+            else:
+                self.folder_regexp = [osp.join(self.dataset_root,exp) for exp in self.config['img_reg_exp']]
+                filenames = sum([glob.glob(exp) for exp in self.folder_regexp],[])
+            filenames = [Path(f) for f in filenames]
+
+            filedict = {str(Path(osp.relpath(f,self.dataset_root)).with_suffix('')): f for f in filenames}
+            # Relative to the folder that was actually globbed (self.h5path), not
+            # the raw config value, so that the keys line up with filedict's
+            jsondict = {str(Path(osp.relpath(j,self.h5path)).with_suffix('')): j for j in json_list}
+            self.filenames = []
+            self.json_list = []
+
+            if self.homoadp:
+                for k in filedict.keys():
+                    if k in jsondict and self.homoadp_resume:
+                        continue
+                    self.filenames.append(str(filedict[k]))
+                self.h5path = None
+                self.use_json = False
+                print(f"Found {len(json_list)} json files from the folder")
+                print(f"Total images are reduced from {len(filenames)} to {len(self.filenames)}")
+            else:
+                for k in filedict.keys():
+                    if k in jsondict:
+                        self.filenames.append(str(filedict[k]))
+                        self.json_list.append(str(jsondict[k]))
+        else:
+            self.folder_regexp = [osp.join(self.dataset_root,exp) for exp in self.config['img_reg_exp']]
+            self.filenames = sum([glob.glob(exp) for exp in self.folder_regexp],[])
+            self.h5path = None
+
+        self.default_config = self.get_default_config()
+        self.dataset_name = self.config['alias']
+        self.size = self.config['preprocessing']['resize']
+        print("Found %d images in %s" % (len(self),self.config['dataset_root']))
+        self.num_pad = int(math.ceil(math.log10(len(self))))+1 if len(self)>0 else 0
+
+    def __len__(self):  # number of image files kept in self.filenames
+        return len(self.filenames)
+
+    def get_padded_filename(self, num_pad, idx):
+        """ Format idx as a decimal string left-padded with zeros to num_pad. """
+        digits = "%d" % (idx)
+        return digits.rjust(num_pad, "0")
+
+    def train_preprocessing(self, data, numpy=False):
+        """ Resize, normalize and optionally augment one sample in place.
+
+        Args:
+            data: sample dict with an 'image' array and, for labelled
+                samples, 'junctions' and 'line_map'.
+            numpy: unused.
+        Returns:
+            The same dict with 'image' as a tensor with a leading axis added,
+            'junctions' as a float tensor when present, and 'line_map' /
+            'valid_mask' set when homographic augmentation is enabled.
+        """
+        image = data['image']
+        junctions = data.get('junctions',None)
+        target_size = self.config['preprocessing']['resize']
+        if list(image.shape[:2]) != target_size:
+            size_old = list(image.shape)[:2]
+            image = cv2.resize(image, tuple(target_size[::-1]), interpolation=cv2.INTER_LINEAR)
+            scales = (image.shape[0] / size_old[0], image.shape[1] / size_old[1])
+            if junctions is not None:
+                junctions *= torch.tensor(scales).reshape(1, 2)
+
+        if self.config['augmentation']['photometric']['enable']:
+            photo_trans_list = self.get_photo_transform()
+            ### Apply photometric transforms
+            np.random.shuffle(photo_trans_list)
+            image_transform = transforms.Compose(photo_trans_list + [photoaug.normalize_image()])
+        else:
+            image_transform = photoaug.normalize_image()
+        image = image_transform(image)
+        if self.config['augmentation']['homographic']['enable']:
+            outputs = self.get_homo_transform()(image, junctions, data['line_map'])
+            junctions = outputs["junctions"]
+            image = outputs["warped_image"]
+            data['line_map'] = torch.tensor(outputs['line_map'])
+            data['valid_mask'] = outputs['valid_mask']
+        data['image'] = torch.from_numpy(image)[None]
+        if junctions is not None:
+            # as_tensor() also accepts a torch.Tensor, which from_numpy() rejects
+            data['junctions'] = torch.as_tensor(junctions).float()
+        return data
+
+    def get_homo_transform(self):
+        """ Get homographic transforms (according to the config).
+
+        The minimum label length comes from config["generation"]
+        ["min_label_len"]: a float is a fraction of the shorter image side, an
+        int is a pixel length at generation resolution (rescaled to the
+        resize resolution), and a missing value disables the restriction.
+
+        Raises:
+            ValueError: if homographic augmentation is not enabled.
+        """
+        homo_config = self.config["augmentation"]["homographic"]["params"]
+        if not self.config["augmentation"]["homographic"]["enable"]:
+            raise ValueError("[Error] Homographic augmentation is not enabled.")
+
+        image_shape = self.config["preprocessing"]["resize"]
+        # Compute the min_label_len from config
+        try:
+            min_label_tmp = self.config["generation"]["min_label_len"]
+        except (KeyError, TypeError):  # section or entry not provided
+            min_label_tmp = None
+
+        if isinstance(min_label_tmp, float):
+            min_label_len = min_label_tmp * min(image_shape)
+        elif isinstance(min_label_tmp, int):
+            # `resize` is a list and cannot be divided as a whole: compare the
+            # first axis of both sizes (NOTE(review): confirm the axis choice)
+            scale_ratio = (image_shape[0]
+                           / self.config["generation"]["image_size"][0])
+            min_label_len = min_label_tmp * scale_ratio
+        else:
+            min_label_len = 0
+
+        return homoaug.homography_transform(
+            image_shape, homo_config, 0, min_label_len)
+
+    def get_photo_transform(self):
+        """ Build the photometric transform objects selected by the config.
+
+        Each primitive named in the config is looked up on `photoaug` and
+        instantiated with its entry of config "params" (no arguments when
+        the entry is missing). Raises ValueError when photometric
+        augmentation is disabled.
+        """
+        photo_config = self.config["augmentation"]["photometric"]
+        if not photo_config["enable"]:
+            raise ValueError(
+                "[Error] Photometric augmentation is not enabled.")
+
+        primitive_names = self.parse_transforms(
+            photo_config["primitives"], photoaug.available_augmentations)
+
+        photometric_trans_lst = []
+        for primitive in primitive_names:
+            kwargs = photo_config["params"].get(primitive, {})
+            photometric_trans_lst.append(getattr(photoaug, primitive)(**kwargs))
+        return photometric_trans_lst
+
+    def parse_transforms(self, names, all_transforms):
+        """ Resolve 'all', a list of names or a single name into a list. """
+        listed = names if isinstance(names, list) else [names]
+        selected = all_transforms if (names == 'all') else listed
+        assert set(selected) <= set(all_transforms)
+        return selected
+
+    def check_files(self):
+        """ Save the paths of all decodable images to '<gt_source_train stem>_filtered.pcache'. """
+        h5path = self.config.get('gt_source_train',None)
+        valid_filenames = []
+        for filename in self.filenames:
+            try:
+                with PIL.Image.open(filename) as image:  # closes the file handle
+                    image.load()  # force a full decode
+                valid_filenames.append(filename)
+            except IOError:
+                print(f"Unable to load image from path: {filename}")
+        new_pcache_path = Path(h5path).with_name(f"{Path(h5path).stem}_filtered.pcache")
+        with open(new_pcache_path, 'w') as _f:
+            _f.writelines(f"{filename}\n" for filename in valid_filenames)
+
+    def check_health(self):
+        """ Verify that every label JSON can be opened and parsed.
+
+        Runs only when both self.h5path and self.json_list are set; images are
+        not checked, so the 'images' list of the report is always empty.
+
+        Returns:
+            dict with 'images' and 'jsons' (lists of failing paths) and
+            'status' (True when nothing failed).
+        """
+        image_fail_list = []
+        json_fail_list = []
+        check_json = self.h5path is not None and self.json_list is not None
+        for idx in tqdm(range(len(self))):
+            if not check_json:
+                continue
+            try:
+                with open(self.json_list[idx],'r') as f:
+                    json.load(f)
+            except (OSError, ValueError):
+                # ValueError covers json.JSONDecodeError and decoding errors
+                print(f'The image {self.filenames[idx]} is broken.')
+                json_fail_list.append(self.json_list[idx])
+        return {'images': image_fail_list, 'jsons': json_fail_list,
+                'status': not json_fail_list}
+
+    def __getitem__(self, idx):
+        """ Load image `idx`, attach its labels when available and preprocess it.
+
+        Labels are read from the h5 file (key: zero-padded index) or from the
+        per-image JSON; without a label source a 'valid_mask' of ones is
+        attached instead. The result goes through train_preprocessing().
+        """
+        try:
+            image_origin = np.array(PIL.Image.open(self.filenames[idx]))
+        except Exception:
+            # Best effort: report the file and fall back to a fixed placeholder
+            # image (path relative to the working directory)
+            print(f"Unable to load image from path: {self.filenames[idx]}")
+            image_origin = np.array(PIL.Image.open('hawp/ssl/config/exports/sa1b/00030043_0.png')) # deal with the failed case
+
+        # NOTE(review): PIL returns RGB but the conversion codes below expect
+        # BGR, so channels are swapped; kept to preserve behaviour - confirm
+        if self.config['gray_scale']:
+            image = cv2.cvtColor(image_origin, cv2.COLOR_BGR2GRAY)
+        else:
+            image = cv2.cvtColor(image_origin, cv2.COLOR_BGR2RGB)
+
+        # Fall back to the configured root when the instance has no dataset_root
+        dataset_root = getattr(self, 'dataset_root', None)
+        if dataset_root is None:
+            dataset_root = osp.join(cfg.DATASET_ROOT,self.config['dataset_root'][0])
+        data = {
+            'fname': self.filenames[idx],
+            'image': image,
+            'image_origin': image_origin,
+            'uuid': osp.relpath(self.filenames[idx],dataset_root),
+        }
+
+        if self.h5path is not None and self.json_list is None:
+            with h5py.File(self.h5path,'r') as f:
+                gt_key = self.get_padded_filename(self.num_pad,idx)
+                # Converted while the file is still open
+                self._attach_labels(data, parse_h5_data(f[gt_key]))
+        elif self.h5path is not None and self.json_list is not None:
+            with open(self.json_list[idx],'r') as f:
+                self._attach_labels(data, json.load(f))
+        else:
+            data['valid_mask'] = torch.ones(self.size,dtype=torch.float32)[None]
+
+        return self.train_preprocessing(data)
+
+    @staticmethod
+    def _attach_labels(data, labels):
+        """ Store 'junctions' and 'line_map' built from `labels` in `data`.
+
+        Junctions that no edge references are dropped; 'line_map' is the
+        symmetric adjacency matrix of the remaining ones and the two
+        junction coordinate columns are swapped.
+        """
+        junctions = torch.tensor(labels['junctions']).float()
+        if junctions.shape[0] == 0:
+            # Keep the tensor 2-D so the column swap below still works
+            junctions = torch.zeros((1,2)).float()
+        edges = torch.tensor(labels['edges']).long()
+        junctions_valid = torch.zeros(len(junctions),dtype=torch.bool)
+        junctions_valid[edges.unique()] = 1
+        # Old junction index -> index among the kept junctions (-1: dropped)
+        junctions_idx = -torch.ones(len(junctions),dtype=torch.long)
+        junctions_idx[junctions_valid] = torch.arange(junctions_valid.sum())
+        edges_remapped = junctions_idx[edges]
+        junctions = junctions[junctions_valid]
+        line_map = torch.zeros(junctions.shape[0],junctions.shape[0],dtype=torch.float32)
+        if len(edges_remapped) > 0:
+            line_map[edges_remapped[:,0],edges_remapped[:,1]] = 1
+            line_map[edges_remapped[:,1],edges_remapped[:,0]] = 1
+        data['line_map'] = line_map
+        data['junctions'] = junctions[:,[1,0]]
+
+    def get_default_config(self):
+        """ Return the fallback configuration (augmentations switched off). """
+        preprocessing = {
+            "resize": [512,512],
+            "blur_size": 11,
+        }
+        augmentation = {
+            "photometric": {"enable": False},
+            "homographic": {"enable": False},
+        }
+        default_config = {
+            "dataset_name": "images",
+            "add_augmentation_to_all_splits": False,
+            "preprocessing": preprocessing,
+            "augmentation": augmentation,
+        }
+        return default_config
\ No newline at end of file
diff --git a/scalelsd/ssl/datasets/nyu_dataset.py b/scalelsd/ssl/datasets/nyu_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d7333fb42ec2f5d8ad06424976a684989a4b5b4
--- /dev/null
+++ b/scalelsd/ssl/datasets/nyu_dataset.py
@@ -0,0 +1,118 @@
+""" NYU Depth v2 dataset for VP estimation evaluation. """
+
+import os
+import csv
+import numpy as np
+import torch
+import cv2
+import scipy.io
+from skimage.io import imread
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+
+from ..config.project_config import Config as cfg
+
+
+def unproject_vp_to_world(vp, K):
+    """ Back-project homogeneous image-plane VPs through inv(K), flip the
+    sign of the y component and normalize each row to unit length. """
+    directions = (np.linalg.inv(K) @ vp.T).T
+    directions[:, 1] = -directions[:, 1]
+    norms = np.linalg.norm(directions, axis=1, keepdims=True)
+    return directions / norms
+
+class NYU(torch.utils.data.Dataset):
+    def __init__(self, mode='test', config=None):
+        """ List the image, VP and labelled-line files of the 1449 NYU frames.
+
+        mode 'val' keeps frames[49:], 'test' keeps frames[:49] and any other
+        mode keeps every frame. config is stored for get_data_loader().
+        """
+        self.conf = config if config is not None else {}
+        # Extract the image names
+        num_imgs = 1449
+        # NOTE(review): negative size => 'test' is the first 49 frames; confirm
+        val_size = -49
+
+        self.root_dir = cfg.nyu_dataroot
+        self.img_paths = [os.path.join(self.root_dir, 'images', 'nyu_rgb_'+str(i+1).zfill(4) + '.png')
+                          for i in range(num_imgs)]
+        self.vps_paths = [os.path.join(self.root_dir, 'vps', 'vps_' + str(i).zfill(4) + '.csv')
+                          for i in range(num_imgs)]
+        self.lines_paths = [os.path.join(self.root_dir, 'labelled_lines', 'labelled_lines_' + str(i).zfill(4) + '.csv')
+                          for i in range(num_imgs)]
+        self.img_names = [str(i).zfill(4) for i in range(num_imgs)]
+
+        # Separate validation and test
+        keep = {'val': slice(-val_size, None),
+                'test': slice(None, -val_size)}.get(mode, slice(None))
+        self.img_paths = self.img_paths[keep]
+        self.vps_paths = self.vps_paths[keep]
+        self.lines_paths = self.lines_paths[keep]
+        self.img_names = self.img_names[keep]
+
+        # Load the intrinsics
+        fx_rgb = 5.1885790117450188e+02
+        fy_rgb = 5.1946961112127485e+02
+        cx_rgb = 3.2558244941119034e+02
+        cy_rgb = 2.5373616633400465e+02
+        self.K = torch.tensor([[fx_rgb, 0, cx_rgb],
+                               [0, fy_rgb, cy_rgb],
+                               [0, 0, 1]])
+
+    def get_dataset(self, split):  # `split` is ignored; the object itself is returned
+        return self
+
+    def __getitem__(self, idx):
+        """ Load frame `idx` with its ground-truth annotations.
+
+        Returns a dict with:
+            'image': array returned by cv2.imread (not normalized),
+            'image_path', 'name': the file path and its stem,
+            'gt_lines': float tensor built from the labelled-lines CSV,
+            'vps': float tensor of the VPs after unproject_vp_to_world(),
+            'K': the intrinsics tensor.
+        """
+        def read_rows(csv_path):
+            # Skip the first row; keep columns 1 and 2 plus a trailing 1.
+            with open(csv_path) as csv_file:
+                rows = list(csv.reader(csv_file, delimiter=' '))[1:]
+            return [[float(row[1]), float(row[2]), 1.] for row in rows]
+
+        img_path = self.img_paths[idx]
+        name = str(Path(img_path).stem)
+        # The image is returned as loaded: no normalization, no tensor
+        img = cv2.imread(img_path)
+
+        # Load the GT VPs
+        vps = read_rows(self.vps_paths[idx])
+        vps = unproject_vp_to_world(np.array(vps), self.K.numpy())
+
+        # Load the labelled lines (same row format as the VPs)
+        lines = read_rows(self.lines_paths[idx])
+
+        # Convert to torch tensors
+        vps = torch.tensor(vps, dtype=torch.float)
+        lines = torch.tensor(lines, dtype=torch.float)
+
+        data = {'image': img,
+                'image_path': img_path,
+                'name': name,
+                'gt_lines': lines,
+                'vps': vps,
+                'K': self.K
+                }
+        return data
+
+    def __len__(self):  # number of frames kept for the selected mode
+        return len(self.img_paths)
+
+ # Overwrite the parent data loader to handle custom split
+    def get_data_loader(self, split, shuffle=False):
+        """Return a sequential data loader over this dataset (shuffle is ignored)."""
+        assert split in ['val', 'test', 'export']
+        # `conf` may never have been set on this instance: default to batches of 1
+        conf = getattr(self, 'conf', None) or {}
+        batch_size = conf.get(split+'_batch_size', 1)
+        return DataLoader(self.get_dataset(split), batch_size=batch_size, shuffle=False,
+                          pin_memory=True, num_workers=conf.get('num_workers', batch_size))
\ No newline at end of file
diff --git a/scalelsd/ssl/datasets/official_yorkurban_dataset.py b/scalelsd/ssl/datasets/official_yorkurban_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf8ad023ee530fb907250fc44c52c5c82f826444
--- /dev/null
+++ b/scalelsd/ssl/datasets/official_yorkurban_dataset.py
@@ -0,0 +1,100 @@
+""" YorkUrban dataset for VP estimation evaluation. """
+
+import os
+import numpy as np
+import torch
+import cv2
+import scipy.io
+from skimage.io import imread
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+
+from ..config.project_config import Config as cfg
+
+
+
+class YorkUrban(torch.utils.data.Dataset):
+    def __init__(self, mode='train', config=None):
+        """ Select the image folders of the requested split ('test' uses
+        'testSetIndex', any other mode 'trainingSetIndex') and load the intrinsics. """
+        assert mode in ['train', 'val', 'test']
+        self.conf = config if config is not None else {}
+
+        # os.listdir() order is arbitrary: sort before applying positional split
+        # indices (NOTE(review): assumes the split file follows sorted names)
+        self.root_dir = cfg.official_yorkurban_dataroot
+        self.img_names = sorted(name for name in os.listdir(self.root_dir)
+                                if os.path.isdir(os.path.join(self.root_dir, name)))
+        assert len(self.img_names) == 102  ## 102 image folders in total
+
+        # Separate validation and test
+        split_file = os.path.join(self.root_dir, 'ECCV_TrainingAndTestImageNumbers.mat')
+        split_mat = scipy.io.loadmat(split_file)
+        split_key = 'testSetIndex' if mode == 'test' else 'trainingSetIndex'
+        valid_set = split_mat[split_key][:, 0] - 1  # stored 1-based -> 0-based
+        self.img_names = np.array(self.img_names)[valid_set]
+        assert len(self.img_names) == 51
+
+        # Load the intrinsics
+        K_file = os.path.join(self.root_dir, 'cameraParameters.mat')
+        K_mat = scipy.io.loadmat(K_file)
+        f = K_mat['focal'][0, 0] / K_mat['pixelSize'][0, 0]
+        p_point = K_mat['pp'][0] - 1 # -1 to convert to 0-based conv
+        self.K = torch.tensor([[f, 0, p_point[0]],
+                               [0, f, p_point[1]],
+                               [0, 0, 1]])
+
+    def __len__(self):  # number of images in the selected split
+        return len(self.img_names)
+
+    def __getitem__(self, idx):
+        """ Load image `idx` with its labelled lines and vanishing points.
+
+        Returns a dict with the image as read by cv2.imread, its path and
+        name, 'gt_lines' (N x 2 x 2, coordinate pairs swapped and shifted
+        by -1), 'vps' (only those referenced by a line), the per-line
+        'vp_association' re-indexed into 'vps', and the intrinsics 'K'.
+        """
+        img_name = self.img_names[idx]
+        img_dir = os.path.join(self.root_dir, img_name)
+        img_path = os.path.join(img_dir, f'{img_name}.jpg')
+        img = cv2.imread(img_path)
+
+        # Load the GT lines and VP association
+        lines_mat = scipy.io.loadmat(
+            os.path.join(img_dir, f'{img_name}LinesAndVP.mat'))
+        lines = lines_mat['lines'].reshape(-1, 2, 2)[:, :, [1, 0]] - 1
+        vp_association = lines_mat['vp_association'][:, 0] - 1
+
+        # Load the VPs (non orthogonal ones)
+        vp_mat = scipy.io.loadmat(
+            os.path.join(img_dir, f'{img_name}GroundTruthVP_CamParams.mat'))
+        vps = vp_mat['vp'].T
+
+        # Keep only the relevant VPs
+        unique_vps = np.unique(vp_association)
+        vps = vps[unique_vps]
+        # Re-index the associations to positions in the reduced VP array
+        for new_index, old_index in enumerate(unique_vps):
+            vp_association[vp_association == old_index] = new_index
+
+        # Convert to torch tensors (the image stays a plain array)
+        return {
+            'image': img,
+            'image_path': img_path,
+            'name': str(Path(img_path).stem),
+            'gt_lines': torch.tensor(lines.astype(float), dtype=torch.float),
+            'vps': torch.tensor(vps, dtype=torch.float),
+            'vp_association': torch.tensor(vp_association, dtype=torch.int),
+            'K': self.K
+        }
+
+ # Overwrite the parent data loader to handle custom collate_fn
+    def get_data_loader(self, split, shuffle=False):
+        """Return a sequential data loader over this dataset (shuffle is ignored)."""
+        assert split in ['val', 'test']
+        # No get_dataset() exists here and `conf` may be unset: wrap self directly
+        conf = getattr(self, 'conf', None) or {}
+        batch_size = conf.get(split+'_batch_size', 1)
+        return DataLoader(self, batch_size=batch_size, shuffle=False, pin_memory=True,
+                          num_workers=conf.get('num_workers', batch_size))
diff --git a/scalelsd/ssl/datasets/rdnim_dataset.py b/scalelsd/ssl/datasets/rdnim_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..632043964eef5c6a65d718503b47fcdda9f2d3a3
--- /dev/null
+++ b/scalelsd/ssl/datasets/rdnim_dataset.py
@@ -0,0 +1,110 @@
+""" Rotated Day-Night Image Matching dataset. """
+
+import os
+import numpy as np
+import torch
+import cv2
+import csv
+from pathlib import Path
+from torch.utils.data import Dataset, DataLoader
+
+from ..config.project_config import Config as cfg
+
+def read_timestamps(text_file):
+    """
+    Parse a space-separated timestamp file with one image per row
+    (name, date, hour, minute) and return a dictionary of parallel
+    lists, adding the time of day in fractional hours as 'time'.
+    """
+    columns = ('name', 'date', 'hour', 'minute', 'time')
+    timestamps = {column: [] for column in columns}
+
+    with open(text_file, 'r') as csvfile:
+        for row in csv.reader(csvfile, delimiter=' '):
+            name, date = row[0], row[1]
+            hour, minute = int(row[2]), int(row[3])
+            # One value per column, in the same order as `columns`
+            parsed = (name, date, hour, minute, hour + minute / 60.)
+            for column, value in zip(columns, parsed):
+                timestamps[column].append(value)
+
+    return timestamps
+
+class RDNIM(torch.utils.data.Dataset):
+ default_conf = {
+ 'dataset_dir': 'RDNIM',
+ 'reference': 'day',
+ }
+
+    def __init__(self, conf):
+        """ Index every RDNIM image with its reference, homography and time.
+
+        conf is merged over default_conf; conf['reference'] names the
+        reference image ('<reference>.jpg') used for every sequence.
+        """
+        self.conf = {**self.default_conf, **(conf or {})}
+        self._root_dir = Path(cfg.rdnim_dataroot)
+        ref = self.conf['reference']
+
+        # Extract the timestamps
+        timestamps = {
+            f.stem: read_timestamps(str(f))
+            for f in Path(self._root_dir, 'time_stamps').iterdir()}
+
+        # Extract the reference images paths
+        references = {
+            seq.stem: str(Path(seq, ref + '.jpg'))
+            for seq in Path(self._root_dir, 'references').iterdir()}
+
+        # Extract the images paths and the homographies. Directory listings
+        # come in arbitrary order, so sort to keep sample indices reproducible
+        self._files = []
+        for seq in sorted(Path(self._root_dir, 'images').iterdir()):
+            seq_id = seq.stem
+            for img in sorted(x for x in seq.iterdir() if x.suffix == '.jpg'):
+                timestamp = timestamps[seq_id]['time'][
+                    timestamps[seq_id]['name'].index(img.name)]
+                H = np.loadtxt(str(img.with_suffix('.txt'))).astype(float)
+                self._files.append({
+                    'img': str(img),
+                    'ref': str(references[seq_id]),
+                    'H': H,
+                    'timestamp': timestamp})
+
+    def __getitem__(self, item):
+        """ Return the reference/target pair `item` as 1 x H x W float tensors
+        in [0, 1] with the homography, timestamp and both image paths.
+
+        Raises:
+            OSError: if one of the two images cannot be read by OpenCV.
+        """
+        sample = self._files[item]
+        images = []
+        for path in (sample['ref'], sample['img']):
+            gray = cv2.imread(path, 0)
+            if gray is None:  # cv2.imread returns None instead of raising
+                raise OSError("[Error] Unable to read image '%s'" % path)
+            # Normalize the images in [0, 1]
+            images.append(torch.tensor(gray[None].astype(float) / 255., dtype=torch.float))
+        return {'image': images[0], 'warped_image': images[1],
+                'H': torch.tensor(sample['H'], dtype=torch.float),
+                'timestamp': sample['timestamp'],
+                'image_path': sample['ref'], 'warped_image_path': sample['img']}
+
+    def __len__(self):  # one entry per image file indexed in self._files
+        return len(self._files)
+
+    def get_dataset(self, split):  # only 'test' is accepted; returns the object itself
+        assert split in ['test']
+        return self
+
+ # Overwrite the parent data loader to handle custom collate_fn
+    def get_data_loader(self, split, shuffle=False):
+        """Return a data loader for the 'test' split, shuffled only on request."""
+        assert split in ['test']
+        # `conf` may be unset on this instance: default to batches of 1
+        conf = getattr(self, 'conf', None) or {}
+        batch_size = conf.get(split+'_batch_size', 1)
+        return DataLoader(self, batch_size=batch_size, shuffle=bool(shuffle), pin_memory=True,
+                          num_workers=conf.get('num_workers', batch_size))
+
diff --git a/scalelsd/ssl/datasets/synthetic_dataset.py b/scalelsd/ssl/datasets/synthetic_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..eef63999685a8c619fe212f13a72526dc1522fc7
--- /dev/null
+++ b/scalelsd/ssl/datasets/synthetic_dataset.py
@@ -0,0 +1,746 @@
+"""
+This file implements the synthetic shape dataset object for pytorch
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import os
+import math
+import h5py
+import pickle
+import torch
+import numpy as np
+import cv2
+from tqdm import tqdm
+from torchvision import transforms
+from torch.utils.data import Dataset
+import torch.utils.data.dataloader as torch_loader
+from ..config.project_config import Config as cfg
+from . import synthetic_util
+from .transforms import photometric_transforms as photoaug
+from .transforms import homographic_transforms as homoaug
+from ..misc.train_utils import parse_h5_data
+
+def synthetic_collate_fn(batch):
+ """ Customized collate_fn. """
+ batch_keys = ["image", "junction_map", "heatmap",
+ "valid_mask", "homography"]
+ list_keys = ["junctions", "line_map", "file_key"]
+ outputs = {}
+ for data_key in batch[0].keys():
+ batch_match = sum([_ in data_key for _ in batch_keys])
+ list_match = sum([_ in data_key for _ in list_keys])
+ # print(batch_match, list_match)
+ if batch_match > 0 and list_match == 0:
+ outputs[data_key] = torch_loader.default_collate([b[data_key]
+ for b in batch])
+ elif batch_match == 0 and list_match > 0:
+ outputs[data_key] = [b[data_key] for b in batch]
+ elif batch_match == 0 and list_match == 0:
+ continue
+ else:
+ raise ValueError(
+ "[Error] A key matches batch keys and list keys simultaneously.")
+ return outputs
+
+
+class SyntheticShapes(Dataset):
+ """ Dataset of synthetic shapes. """
+ # Initialize the dataset
+ def __init__(self, mode="train", config=None):
+ super(SyntheticShapes, self).__init__()
+ if not mode in ["train", "val", "test"]:
+ raise ValueError(
+ "[Error] Supported dataset modes are 'train', 'val', and 'test'.")
+ self.mode = mode
+
+ # Get configuration
+ if config is None:
+ self.config = self.get_default_config()
+ else:
+ self.config = config
+
+ # Set all available primitives
+ self.available_primitives = [
+ 'draw_checkerboard_multiseg',
+ 'draw_lines',
+ 'draw_polygon',
+ 'draw_multiple_polygons',
+ 'draw_star',
+ 'draw_stripes_multiseg',
+ 'draw_cube',
+ 'gaussian_noise'
+ ]
+ # Some cache setting
+ self.dataset_name = self.get_dataset_name()
+ self.cache_name = self.get_cache_name()
+ self.cache_path = cfg.synthetic_cache_path
+
+ # Check if export dataset exists
+ print("===============================================")
+ self.filename_dataset, self.datapoints = self.construct_dataset()
+ self.print_dataset_info()
+
+
+ self.dataset_length = len(self.datapoints)
+ # Initialize h5 file handle
+ self.dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5")
+
+ # Fix the random seed for torch and numpy in testing mode
+ if ((self.mode == "val" or self.mode == "test")
+ and self.config["add_augmentation_to_all_splits"]):
+ seed = self.config.get("test_augmentation_seed", 200)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ # For CuDNN
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ self.data_prefetch = {}
+ # # if 1 == 1:
+ # self.data_prefetch = []
+ # for i in range(len(self.datapoints)):
+ # if i % 100 == 0:
+ # print('loading %d/%d'%(i,len(self.datapoints)))
+ # dp = self.datapoints[i]
+ # with h5py.File(self.dataset_path, "r", swmr=True) as reader:
+ # data = self.get_data_from_datapoint(dp, reader)
+ # self.data_prefetch.append(data)
+
+ ##########################################
+ ## Dataset construction related methods ##
+ ##########################################
+ def construct_dataset(self):
+ """ Dataset constructor. """
+ # Check if the filename cache exists
+ # If cache exists, load from cache
+ if self._check_dataset_cache():
+ print("[Info]: Found filename cache at ...")
+ print("\t Load filename cache...")
+ filename_dataset, datapoints = self.get_filename_dataset_from_cache()
+ print("\t Check if all file exists...")
+ # If all file exists, continue
+ if self._check_file_existence(filename_dataset):
+ print("\t All files exist!")
+ # If not, need to re-export the synthetic dataset
+ else:
+ print("\t Some files are missing. Re-export the synthetic shape dataset.")
+ self.export_synthetic_shapes()
+ print("\t Initialize filename dataset")
+ filename_dataset, datapoints = self.get_filename_dataset()
+ print("\t Create filename dataset cache...")
+ self.create_filename_dataset_cache(filename_dataset,
+ datapoints)
+
+ # If not, initialize dataset from scratch
+ else:
+ print("[Info]: Can't find filename cache ...")
+ print("\t First check export dataset exists.")
+ # If export dataset exists, then just update the filename_dataset
+ if self._check_export_dataset():
+ print("\t Synthetic dataset exists. Initialize the dataset ...")
+
+ # If export dataset does not exist, export from scratch
+ else:
+ print("\t Synthetic dataset does not exist. Export the synthetic dataset.")
+ self.export_synthetic_shapes()
+ print("\t Initialize filename dataset")
+
+ filename_dataset, datapoints = self.get_filename_dataset()
+ print("\t Create filename dataset cache...")
+ self.create_filename_dataset_cache(filename_dataset, datapoints)
+
+ return filename_dataset, datapoints
+
+ def get_cache_name(self):
+ """ Get cache name from dataset config / default config. """
+ if self.config["dataset_name"] is None:
+ dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
+ else:
+ dataset_name = self.config["dataset_name"] + "_%s" % self.mode
+ if self.config.get('alias',None):
+ dataset_name = dataset_name + '-%s'%self.config['alias']
+
+ # Compose cache name
+ cache_name = dataset_name + "_cache.pkl"
+
+ return cache_name
+
+ def get_dataset_name(self):
+ """Get dataset name from dataset config / default config. """
+ if self.config["dataset_name"] is None:
+ dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
+ else:
+ dataset_name = self.config["dataset_name"] + "_%s" % self.mode
+ if self.config.get('alias',None):
+ dataset_name = dataset_name + '-%s'%self.config['alias']
+
+ return dataset_name
+
+ def get_filename_dataset_from_cache(self):
+ """ Get filename dataset from cache. """
+ # Load from the pkl cache
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ with open(cache_file_path, "rb") as f:
+ data = pickle.load(f)
+
+ return data["filename_dataset"], data["datapoints"]
+
+ def get_filename_dataset(self):
+ """ Get filename dataset from scratch. """
+ # Path to the exported dataset
+ dataset_path = os.path.join(cfg.synthetic_dataroot,
+ self.dataset_name + ".h5")
+
+ filename_dataset = {}
+ datapoints = []
+ # Open the h5 dataset
+ with h5py.File(dataset_path, "r") as f:
+ # Iterate through all the primitives
+ for prim_name in f.keys():
+ filenames = sorted(f[prim_name].keys())
+ filenames_full = [os.path.join(prim_name, _)
+ for _ in filenames]
+
+ filename_dataset[prim_name] = filenames_full
+ datapoints += filenames_full
+
+ return filename_dataset, datapoints
+
+ def create_filename_dataset_cache(self, filename_dataset, datapoints):
+ """ Create filename dataset cache for faster initialization. """
+ # Check cache path exists
+ if not os.path.exists(self.cache_path):
+ os.makedirs(self.cache_path)
+
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ data = {
+ "filename_dataset": filename_dataset,
+ "datapoints": datapoints
+ }
+ with open(cache_file_path, "wb") as f:
+ pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
+
+ def export_synthetic_shapes(self):
+ """ Export synthetic shapes to disk. """
+ # Set the global random state for data generation
+ synthetic_util.set_random_state(np.random.RandomState(
+ self.config["generation"]["random_seed"]))
+ # Define the export path
+ dataset_path = os.path.join(cfg.synthetic_dataroot,
+ self.dataset_name + ".h5")
+
+ # Open h5py file
+ with h5py.File(dataset_path, "w", libver="latest") as f:
+ # Iterate through all types of shape
+ primitives = self.parse_drawing_primitives(
+ self.config["primitives"])
+ split_size = self.config["generation"]["split_sizes"][self.mode]
+ for prim in primitives:
+ # Create h5 group
+ group = f.create_group(prim)
+ # Export single primitive
+ self.export_single_primitive(prim, split_size, group)
+
+ f.swmr_mode = True
+
+    def export_single_primitive(self, primitive, split_size, group):
+        """ Export single primitive.
+
+        Generates `split_size` samples of one drawing primitive and stores
+        each of them as a sub-group of `group` holding four gzip-compressed
+        datasets: "points", "image", "line_map" and "heatmap".
+
+        Args:
+            primitive: name of a drawing function of synthetic_util; must be
+                listed in self.available_primitives.
+            split_size: number of samples to generate.
+            group: object providing create_group(), used as the parent of
+                the per-sample groups.
+
+        Raises:
+            ValueError: if `primitive` is not a supported primitive.
+        """
+        # Check if the primitive is valid or not
+        if primitive not in self.available_primitives:
+            raise ValueError(
+                "[Error]: %s is not a supported primitive" % primitive)
+        # Set the random seed (re-seeded for every primitive)
+        synthetic_util.set_random_state(np.random.RandomState(
+            self.config["generation"]["random_seed"]))
+
+        # Generate shapes
+        print("\t Generating %s ..." % primitive)
+        for idx in tqdm(range(split_size), ascii=True):
+            # Generate background image
+            # NOTE(review): the generated background is discarded by the
+            # next statement, so shapes are drawn on an all-zero canvas.
+            # The call still runs; presumably removing it would change the
+            # random draws that follow -- confirm before deleting it.
+            image = synthetic_util.generate_background(
+                self.config['generation']['image_size'],
+                **self.config['generation']['params']['generate_background'])
+            image = np.zeros(self.config['generation']['image_size'], dtype=np.uint8)
+
+            # Generate points
+            drawing_func = getattr(synthetic_util, primitive)
+            kwarg = self.config["generation"]["params"].get(primitive, {})
+
+            # Get min_len and min_label_len
+            min_len = self.config["generation"]["min_len"]
+            min_label_len = self.config["generation"]["min_label_len"]
+
+            # Some only take min_label_len, and gaussian noises take nothing
+            if primitive in ["draw_lines", "draw_polygon",
+                             "draw_multiple_polygons", "draw_star"]:
+                data = drawing_func(image, min_len=min_len,
+                                    min_label_len=min_label_len, **kwarg)
+            elif primitive in ["draw_checkerboard_multiseg",
+                               "draw_stripes_multiseg", "draw_cube"]:
+                data = drawing_func(image, min_label_len=min_label_len,
+                                    **kwarg)
+            else:
+                data = drawing_func(image, **kwarg)
+
+            # Convert the data (the two coordinate columns are swapped)
+            if data["points"] is not None:
+                points = np.flip(data["points"], axis=1).astype('float')
+                line_map = data["line_map"].astype(np.int32)
+            else:
+                # No labelled point: store empty arrays
+                points = np.zeros([0, 2]).astype('float')
+                line_map = np.zeros([0, 0]).astype(np.int32)
+
+            # Post-processing
+            blur_size = self.config["preprocessing"]["blur_size"]
+            image = cv2.GaussianBlur(image, (blur_size, blur_size), 0)
+
+            # Resize the image and the point location.
+            points = (points
+                      * np.array(self.config['preprocessing']['resize'],
+                                 'float')
+                      / np.array(self.config['generation']['image_size'],
+                                 'float'))
+            # cv2.resize takes the size in reversed order, hence [::-1]
+            image = cv2.resize(
+                image, tuple(self.config['preprocessing']['resize'][::-1]),
+                interpolation=cv2.INTER_LINEAR)
+            image = np.array(image, dtype=np.uint8)
+
+            # Generate the line heatmap after post-processing
+            junctions = np.flip(np.round(points).astype(np.int32), axis=1)
+            heatmap = (synthetic_util.get_line_heatmap(
+                junctions, line_map,
+                size=image.shape) * 255.).astype(np.uint8)
+
+            # Record the data in group
+            # Zero-pad the sample index to a width derived from split_size
+            num_pad = math.ceil(math.log10(split_size)) + 1
+            file_key_name = self.get_padded_filename(num_pad, idx)
+            file_group = group.create_group(file_key_name)
+
+            # Store data
+            file_group.create_dataset("points", data=points,
+                                      compression="gzip")
+            file_group.create_dataset("image", data=image,
+                                      compression="gzip")
+            file_group.create_dataset("line_map", data=line_map,
+                                      compression="gzip")
+            file_group.create_dataset("heatmap", data=heatmap,
+                                      compression="gzip")
+
+
+ def get_default_config(self):
+ """ Get default configuration of the dataset. """
+ # Initialize the default configuration
+ self.default_config = {
+ "dataset_name": "synthetic_shape",
+ "primitives": "all",
+ "add_augmentation_to_all_splits": False,
+ # Shape generation configuration
+ "generation": {
+ "split_sizes": {'train': 10000, 'val': 400, 'test': 500},
+ "random_seed": 10,
+ "image_size": [960, 1280],
+ "min_len": 0.09,
+ "min_label_len": 0.1,
+ 'params': {
+ 'generate_background': {
+ 'min_kernel_size': 150, 'max_kernel_size': 500,
+ 'min_rad_ratio': 0.02, 'max_rad_ratio': 0.031},
+ 'draw_stripes': {'transform_params': (0.1, 0.1)},
+ 'draw_multiple_polygons': {'kernel_boundaries': (50, 100)}
+ },
+ },
+ # Date preprocessing configuration.
+ "preprocessing": {
+ "resize": [240, 320],
+ "blur_size": 11
+ },
+ 'augmentation': {
+ 'photometric': {
+ 'enable': False,
+ 'primitives': 'all',
+ 'params': {},
+ 'random_order': True,
+ },
+ 'homographic': {
+ 'enable': False,
+ 'params': {},
+ 'valid_border_margin': 0,
+ },
+ }
+ }
+
+ return self.default_config
+
+ def parse_drawing_primitives(self, names):
+ """ Parse the primitives in config to list of primitive names. """
+ if names == "all":
+ p = self.available_primitives
+ else:
+ if isinstance(names, list):
+ p = names
+ else:
+ p = [names]
+
+ assert set(p) <= set(self.available_primitives)
+
+ return p
+
+ @staticmethod
+ def get_padded_filename(num_pad, idx):
+ """ Get the padded filename using adaptive padding. """
+ file_len = len("%d" % (idx))
+ filename = "0" * (num_pad - file_len) + "%d" % (idx)
+
+ return filename
+
+ def print_dataset_info(self):
+ """ Print dataset info. """
+ print("\t ---------Summary------------------")
+ print("\t Dataset mode: \t\t %s" % self.mode)
+ print("\t Number of primitive: \t %d" % len(self.filename_dataset.keys()))
+ print("\t Number of data: \t %d" % len(self.datapoints))
+ print("\t ----------------------------------")
+
+ #########################
+ ## Pytorch related API ##
+ #########################
+ def get_data_from_datapoint(self, datapoint, reader=None):
+ """ Get data given the datapoint
+ (keyname of the h5 dataset e.g. "draw_lines/0000.h5"). """
+ # Check if the datapoint is valid
+ if not datapoint in self.datapoints:
+ raise ValueError(
+ "[Error] The specified datapoint is not in available datapoints.")
+
+ # Get data from h5 dataset
+ if reader is None:
+ raise ValueError(
+ "[Error] The reader must be provided in __getitem__.")
+ else:
+ data = reader[datapoint]
+
+ return parse_h5_data(data)
+
+ def get_data_from_signature(self, primitive_name, index):
+ """ Get data given the primitive name and index ("draw_lines", 10) """
+ # Check the primitive name and index
+ self._check_primitive_and_index(primitive_name, index)
+
+ # Get the datapoint from filename dataset
+ datapoint = self.filename_dataset[primitive_name][index]
+
+ return self.get_data_from_datapoint(datapoint)
+
+ def parse_transforms(self, names, all_transforms):
+ trans = all_transforms if (names == 'all') \
+ else (names if isinstance(names, list) else [names])
+ assert set(trans) <= set(all_transforms)
+ return trans
+
+ def get_photo_transform(self):
+ """ Get list of photometric transforms (according to the config). """
+ # Get the photometric transform config
+ photo_config = self.config["augmentation"]["photometric"]
+ if not photo_config["enable"]:
+ raise ValueError(
+ "[Error] Photometric augmentation is not enabled.")
+
+ # Parse photometric transforms
+ trans_lst = self.parse_transforms(photo_config["primitives"],
+ photoaug.available_augmentations)
+ trans_config_lst = [photo_config["params"].get(p, {})
+ for p in trans_lst]
+
+ # List of photometric augmentation
+ photometric_trans_lst = [
+ getattr(photoaug, trans)(**conf) \
+ for (trans, conf) in zip(trans_lst, trans_config_lst)
+ ]
+
+ return photometric_trans_lst
+
+ def get_homo_transform(self):
+ """ Get homographic transforms (according to the config). """
+ # Get homographic transforms for image
+ homo_config = self.config["augmentation"]["homographic"]["params"]
+ if not self.config["augmentation"]["homographic"]["enable"]:
+ raise ValueError(
+ "[Error] Homographic augmentation is not enabled")
+
+ # Parse the homographic transforms
+ # ToDo: use the shape from the config
+ image_shape = self.config["preprocessing"]["resize"]
+
+ # Compute the min_label_len from config
+ try:
+ min_label_tmp = self.config["generation"]["min_label_len"]
+ except:
+ min_label_tmp = None
+
+ # float label len => fraction
+ if isinstance(min_label_tmp, float): # Skip if not provided
+ min_label_len = min_label_tmp * min(image_shape)
+ # int label len => length in pixel
+ elif isinstance(min_label_tmp, int):
+ scale_ratio = (self.config["preprocessing"]["resize"]
+ / self.config["generation"]["image_size"][0])
+ min_label_len = (self.config["generation"]["min_label_len"]
+ * scale_ratio)
+ # if none => no restriction
+ else:
+ min_label_len = 0
+
+ # Initialize the transform
+ homographic_trans = homoaug.homography_transform(
+ image_shape, homo_config, 0, min_label_len)
+
+ return homographic_trans
+
+ @staticmethod
+ def junc_to_junc_map(junctions, image_size):
+ """ Convert junction points to junction maps. """
+ junctions = np.round(junctions).astype(np.int32)
+
+ # Clip the boundary by image size
+ junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1)
+ junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1)
+
+ # Create junction map
+ junc_map = np.zeros([image_size[0], image_size[1]])
+ junc_map[junctions[:, 0], junctions[:, 1]] = 1
+
+ return junc_map[..., None].astype(np.int32)
+
+    def train_preprocessing(self, data, disable_homoaug=False):
+        """ Training preprocessing.
+
+        Resizes the sample to the configured size when needed, applies the
+        photometric transforms (or only the normalization) and, when
+        enabled, the homographic transform.
+
+        Args:
+            data: mapping with the entries "image", "points", "line_map"
+                and "heatmap".
+            disable_homoaug: when not equal to False, the homographic
+                augmentation is skipped even if the config enables it.
+
+        Returns:
+            Dict with "image", "junctions", "line_map" and "valid_mask",
+            plus "homography_mat" when the homographic transform was
+            applied. The heatmap is read (and recomputed on resize) but is
+            not part of the returned dict.
+        """
+        # Fetch corresponding entries
+        image = data["image"]
+        junctions = data["points"]
+        line_map = data["line_map"]
+        heatmap = data["heatmap"]
+        image_size = image.shape[:2]
+
+        # Resize the image before the photometric and homographic transforms
+        # Check if we need to do the resizing
+        if not(list(image.shape) == self.config["preprocessing"]["resize"]):
+            # Resize the image and the point location.
+            size_old = list(image.shape)
+            # cv2.resize takes the size in reversed order, hence [::-1]
+            image = cv2.resize(
+                image, tuple(self.config['preprocessing']['resize'][::-1]),
+                interpolation=cv2.INTER_LINEAR)
+            image = np.array(image, dtype=np.uint8)
+
+            junctions = (
+                junctions
+                * np.array(self.config['preprocessing']['resize'], float)
+                / np.array(size_old, float))
+
+            # Generate the line heatmap after post-processing
+            junctions_xy = np.flip(np.round(junctions).astype(np.int32),
+                                   axis=1)
+            heatmap = synthetic_util.get_line_heatmap(junctions_xy, line_map,
+                                                      size=image.shape)
+            heatmap = (heatmap * 255.).astype(np.uint8)
+
+            # Update image size
+            image_size = image.shape[:2]
+
+        # Declare default valid mask (all ones)
+        valid_mask = np.ones(image_size)
+
+        # Check if we need to apply augmentations
+        # In training mode => yes.
+        # In homography adaptation mode (export mode) => No
+        # Check photometric augmentation
+        if self.config["augmentation"]["photometric"]["enable"]:
+            photo_trans_lst = self.get_photo_transform()
+            ### Image transform ###
+            # The transforms are applied in a random order; the
+            # normalization always comes last.
+            np.random.shuffle(photo_trans_lst)
+            image_transform = transforms.Compose(
+                photo_trans_lst + [photoaug.normalize_image()])
+        else:
+            image_transform = photoaug.normalize_image()
+        image = image_transform(image)
+
+        # Initialize the empty output dict
+        outputs = {}
+        # Convert to tensor and return the results
+        to_tensor = transforms.ToTensor()
+        # Check homographic augmentation
+        if (self.config["augmentation"]["homographic"]["enable"]
+                and disable_homoaug == False):
+            homo_trans = self.get_homo_transform()
+            # Perform homographic transform
+            homo_outputs = homo_trans(image, junctions, line_map)
+
+            # Record the warped results
+            junctions = homo_outputs["junctions"]  # Should be HW format
+            image = homo_outputs["warped_image"]
+            line_map = homo_outputs["line_map"]
+            # heatmap = homo_outputs["warped_heatmap"]
+            valid_mask = homo_outputs["valid_mask"]  # Same for pos and neg
+            homography_mat = homo_outputs["homo"]
+
+            # Optionally put warpping information first.
+            # NOTE(review): [0, ...] presumably drops the leading axis that
+            # ToTensor adds to 2D arrays -- confirm.
+            outputs["homography_mat"] = to_tensor(
+                homography_mat).to(torch.float32)[0, ...]
+
+        # junction_map = self.junc_to_junc_map(junctions, image_size)
+
+        outputs.update({
+            "image": to_tensor(image),
+            "junctions": to_tensor(np.ascontiguousarray(
+                junctions).copy()).to(torch.float32)[0, ...],
+            # "junction_map": to_tensor(junction_map).to(torch.int),
+            "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+            # "heatmap": to_tensor(heatmap).to(torch.int32),
+            "valid_mask": to_tensor(valid_mask).to(torch.int32),
+        })
+
+        return outputs
+
+
+ def test_preprocessing(self, data):
+ """ Test preprocessing. """
+ # Fetch corresponding entries
+ image = data["image"]
+ points = data["points"]
+ line_map = data["line_map"]
+ heatmap = data["heatmap"]
+ image_size = image.shape[:2]
+
+ # Resize the image before the photometric and homographic transforms
+ if not (list(image.shape) == self.config["preprocessing"]["resize"]):
+ # Resize the image and the point location.
+ size_old = list(image.shape)
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ points = (points
+ * np.array(self.config['preprocessing']['resize'],
+ float)
+ / np.array(size_old, float))
+
+ # Generate the line heatmap after post-processing
+ junctions = np.flip(np.round(points).astype(np.int32), axis=1)
+ heatmap = synthetic_util.get_line_heatmap(junctions, line_map,
+ size=image.shape)
+ heatmap = (heatmap * 255.).astype(np.uint8)
+
+ # Update image size
+ image_size = image.shape[:2]
+
+ ### image transform ###
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ ### joint transform ###
+ junction_map = self.junc_to_junc_map(points, image_size)
+ to_tensor = transforms.ToTensor()
+ image = to_tensor(image)
+ junctions = to_tensor(points)
+ junction_map = to_tensor(junction_map).to(torch.int)
+ line_map = to_tensor(line_map)
+ heatmap = to_tensor(heatmap)
+ valid_mask = to_tensor(np.ones(image_size)).to(torch.int32)
+
+ return {
+ "image": image,
+ "junctions": junctions,
+ "junction_map": junction_map,
+ "line_map": line_map,
+ "heatmap": heatmap,
+ "valid_mask": valid_mask
+ }
+
+ def __getitem__(self, index):
+ datapoint = self.datapoints[index]
+ with h5py.File(self.dataset_path, "r", swmr=True) as reader:
+ data = self.get_data_from_datapoint(datapoint, reader)
+
+ edges = np.stack(data['line_map'].nonzero()).transpose()
+ junctions = data['points'][:,::-1]
+ # lines = junctions[edges].reshape(-1,4)
+ # image = data['image']
+
+ # import matplotlib.pyplot as plt
+ # plt.imshow(image)
+ # plt.plot([lines[:,0],lines[:,2]],[lines[:,1],lines[:,3]],'r-')
+ # plt.plot(junctions[:,0],junctions[:,1],'b.')
+ # plt.show()
+ # Apply different transforms in different mod.
+ if (self.mode == "train"
+ or self.config["add_augmentation_to_all_splits"]):
+ return_type = self.config.get("return_type", "single")
+ # data = self.test_preprocessing(data)
+ data = self.train_preprocessing(data)
+ else:
+ data = self.test_preprocessing(data)
+
+ return data
+
+ def __len__(self):
+ # return len(self.datapoints)
+ return self.dataset_length
+
+ ########################
+ ## Some other methods ##
+ ########################
+ def _check_dataset_cache(self):
+ """ Check if dataset cache exists. """
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ if os.path.exists(cache_file_path):
+ return True
+ else:
+ return False
+
+ def _check_export_dataset(self):
+ """ Check if exported dataset exists. """
+ dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name)
+ if os.path.exists(dataset_path) and len(os.listdir(dataset_path)) > 0:
+ return True
+ else:
+ return False
+
+ def _check_file_existence(self, filename_dataset):
+ """ Check if all exported file exists. """
+ # Path to the exported dataset
+ dataset_path = os.path.join(cfg.synthetic_dataroot,
+ self.dataset_name + ".h5")
+
+ flag = True
+ # Open the h5 dataset
+ with h5py.File(dataset_path, "r") as f:
+ # Iterate through all the primitives
+ for prim_name in f.keys():
+ if (len(filename_dataset[prim_name])
+ != len(f[prim_name].keys())):
+ flag = False
+
+ return flag
+
+ def _check_primitive_and_index(self, primitive, index):
+ """ Check if the primitve and index are valid. """
+ # Check primitives
+ if not primitive in self.available_primitives:
+ raise ValueError(
+ "[Error] The primitive is not in available primitives.")
+
+ prim_len = len(self.filename_dataset[primitive])
+ # Check the index
+ if not index < prim_len:
+ raise ValueError(
+ "[Error] The index exceeds the total file counts %d for %s"
+ % (prim_len, primitive))
diff --git a/scalelsd/ssl/datasets/synthetic_util.py b/scalelsd/ssl/datasets/synthetic_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8cd9917b7c2504718159d133037843034ffb85c
--- /dev/null
+++ b/scalelsd/ssl/datasets/synthetic_util.py
@@ -0,0 +1,1243 @@
+"""
+Code adapted from https://github.com/rpautrat/SuperPoint
+Module used to generate geometrical synthetic shapes
+"""
+import math
+import cv2 as cv
+import numpy as np
+import shapely.geometry
+from itertools import combinations
+
+random_state = np.random.RandomState(None)
+
+
+def set_random_state(state):
+ global random_state
+ random_state = state
+
+
+def get_random_color(background_color):
+ """ Output a random scalar in grayscale with a least a small contrast
+ with the background color. """
+ color = random_state.randint(256)
+ if abs(color - background_color) < 30: # not enough contrast
+ color = (color + 128) % 256
+ return color
+
+
+def get_different_color(previous_colors, min_dist=50, max_count=20):
+ """ Output a color that contrasts with the previous colors.
+ Parameters:
+ previous_colors: np.array of the previous colors
+ min_dist: the difference between the new color and
+ the previous colors must be at least min_dist
+ max_count: maximal number of iterations
+ """
+ color = random_state.randint(256)
+ count = 0
+ while np.any(np.abs(previous_colors - color) < min_dist) and count < max_count:
+ count += 1
+ color = random_state.randint(256)
+ return color
+
+
+def add_salt_and_pepper(img):
+ """ Add salt and pepper noise to an image. """
+ noise = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
+ cv.randu(noise, 0, 255)
+ black = noise < 30
+ white = noise > 225
+ img[white > 0] = 255
+ img[black > 0] = 0
+ cv.blur(img, (5, 5), img)
+ return np.empty((0, 2), dtype=np.int32)
+
+
+def generate_background(size=(960, 1280), nb_blobs=100, min_rad_ratio=0.01,
+ max_rad_ratio=0.05, min_kernel_size=50,
+ max_kernel_size=300):
+ """ Generate a customized background image.
+ Parameters:
+ size: size of the image
+ nb_blobs: number of circles to draw
+ min_rad_ratio: the radius of blobs is at least min_rad_size * max(size)
+ max_rad_ratio: the radius of blobs is at most max_rad_size * max(size)
+ min_kernel_size: minimal size of the kernel
+ max_kernel_size: maximal size of the kernel
+ """
+ img = np.zeros(size, dtype=np.uint8)
+ dim = max(size)
+ cv.randu(img, 0, 255)
+ cv.threshold(img, random_state.randint(256), 255, cv.THRESH_BINARY, img)
+ background_color = int(np.mean(img))
+ blobs = np.concatenate(
+ [random_state.randint(0, size[1], size=(nb_blobs, 1)),
+ random_state.randint(0, size[0], size=(nb_blobs, 1))], axis=1)
+ for i in range(nb_blobs):
+ col = get_random_color(background_color)
+ cv.circle(img, (blobs[i][0], blobs[i][1]),
+ np.random.randint(int(dim * min_rad_ratio),
+ int(dim * max_rad_ratio)),
+ col, -1)
+ kernel_size = random_state.randint(min_kernel_size, max_kernel_size)
+ cv.blur(img, (kernel_size, kernel_size), img)
+ return img
+
+
+def generate_custom_background(size, background_color, nb_blobs=3000,
+ kernel_boundaries=(50, 100)):
+ """ Generate a customized background to fill the shapes.
+ Parameters:
+ background_color: average color of the background image
+ nb_blobs: number of circles to draw
+ kernel_boundaries: interval of the possible sizes of the kernel
+ """
+ img = np.zeros(size, dtype=np.uint8)
+ img = img + get_random_color(background_color)
+ blobs = np.concatenate(
+ [np.random.randint(0, size[1], size=(nb_blobs, 1)),
+ np.random.randint(0, size[0], size=(nb_blobs, 1))], axis=1)
+ for i in range(nb_blobs):
+ col = get_random_color(background_color)
+ cv.circle(img, (blobs[i][0], blobs[i][1]),
+ np.random.randint(20), col, -1)
+ kernel_size = np.random.randint(kernel_boundaries[0],
+ kernel_boundaries[1])
+ cv.blur(img, (kernel_size, kernel_size), img)
+ return img
+
+
+def final_blur(img, kernel_size=(5, 5)):
+ """ Gaussian blur applied to an image.
+ Parameters:
+ kernel_size: size of the kernel
+ """
+ cv.GaussianBlur(img, kernel_size, 0, img)
+
+
+def ccw(A, B, C, dim):
+ """ Check if the points are listed in counter-clockwise order. """
+ if dim == 2: # only 2 dimensions
+ return((C[:, 1] - A[:, 1]) * (B[:, 0] - A[:, 0])
+ > (B[:, 1] - A[:, 1]) * (C[:, 0] - A[:, 0]))
+ else: # dim should be equal to 3
+ return((C[:, 1, :] - A[:, 1, :])
+ * (B[:, 0, :] - A[:, 0, :])
+ > (B[:, 1, :] - A[:, 1, :])
+ * (C[:, 0, :] - A[:, 0, :]))
+
+
+def intersect(A, B, C, D, dim):
+ """ Return true if line segments AB and CD intersect """
+ return np.any((ccw(A, C, D, dim) != ccw(B, C, D, dim)) &
+ (ccw(A, B, C, dim) != ccw(A, B, D, dim)))
+
+
+def keep_points_inside(points, size):
+ """ Keep only the points whose coordinates are inside the dimensions of
+ the image of size 'size' """
+ mask = (points[:, 0] >= 0) & (points[:, 0] < size[1]) &\
+ (points[:, 1] >= 0) & (points[:, 1] < size[0])
+ return points[mask, :]
+
+
+def get_unique_junctions(segments, min_label_len):
+ """ Get unique junction points from line segments. """
+ # Get all junctions from segments
+ junctions_all = np.concatenate((segments[:, :2], segments[:, 2:]), axis=0)
+ if junctions_all.shape[0] == 0:
+ junc_points = None
+ line_map = None
+
+ # Get all unique junction points
+ else:
+ junc_points = np.unique(junctions_all, axis=0)
+ # Generate line map from points and segments
+ line_map = get_line_map(junc_points, segments)
+
+ return junc_points, line_map
+
+
+def get_line_map(points: np.ndarray, segments: np.ndarray) -> np.ndarray:
+ """ Get line map given the points and segment sets. """
+ # create empty line map
+ num_point = points.shape[0]
+ line_map = np.zeros([num_point, num_point])
+
+ # Iterate through every segment
+ for idx in range(segments.shape[0]):
+ # Get the junctions from a single segement
+ seg = segments[idx, :]
+ junction1 = seg[:2]
+ junction2 = seg[2:]
+
+ # Get index
+ idx_junction1 = np.where((points == junction1).sum(axis=1) == 2)[0]
+ idx_junction2 = np.where((points == junction2).sum(axis=1) == 2)[0]
+
+ # label the corresponding entries
+ line_map[idx_junction1, idx_junction2] = 1
+ line_map[idx_junction2, idx_junction1] = 1
+
+ return line_map
+
+
+def get_line_heatmap(junctions, line_map, size=[480, 640], thickness=1):
+ """ Get line heat map from junctions and line map. """
+ # Make sure that the thickness is 1
+ if not isinstance(thickness, int):
+ thickness = int(thickness)
+
+ # If the junction points are not int => round them and convert to int
+ if not junctions.dtype == np.int32:
+ junctions = (np.round(junctions)).astype(np.int32)
+
+ # Initialize empty map
+ heat_map = np.zeros(size)
+
+ if junctions.shape[0] > 0: # If empty, just return zero map
+ # Iterate through all the junctions
+ for idx in range(junctions.shape[0]):
+ # if no connectivity, just skip it
+ if line_map[idx, :].sum() == 0:
+ continue
+ # Plot the line segment
+ else:
+ # Iterate through all the connected junctions
+ for idx2 in np.where(line_map[idx, :] == 1)[0]:
+ point1 = junctions[idx, :]
+ point2 = junctions[idx2, :]
+
+ # Draw line
+ cv.line(heat_map, tuple(point1), tuple(point2), 1., thickness)
+
+ return heat_map
+
+
+def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32):
+    """ Draw random lines and output the positions of the pair of junctions
+    and line associativities.
+    Parameters:
+      img: image the lines are drawn on, in place
+      nb_lines: exclusive upper bound of the number of line attempts
+                (the count is drawn with random_state.randint(1, nb_lines))
+      min_len: minimal length of a drawn line; a float <= 1 is a fraction
+               of the smallest image dimension
+      min_label_len: minimal length of a line that is recorded in the
+                     output; a float <= 1 is a fraction of the smallest
+                     image dimension
+    Returns a dict with "points" (one [x, y] row per recorded endpoint)
+    and "line_map" (connectivity matrix of those points).
+    """
+    # Set line number and points placeholder
+    num_lines = random_state.randint(1, nb_lines)
+    segments = np.empty((0, 4), dtype=np.int32)
+    points = np.empty((0, 2), dtype=np.int32)
+
+    background_color = int(np.mean(img))
+    min_dim = min(img.shape)
+
+    # Convert length constrain to pixel if given float number
+    if isinstance(min_len, float) and min_len <= 1.:
+        min_len = int(min_dim * min_len)
+    if isinstance(min_label_len, float) and min_label_len <= 1.:
+        min_label_len = int(min_dim * min_label_len)
+
+    # Generate lines one by one
+    for i in range(num_lines):
+        x1 = random_state.randint(img.shape[1])
+        y1 = random_state.randint(img.shape[0])
+        p1 = np.array([[x1, y1]])
+        x2 = random_state.randint(img.shape[1])
+        y2 = random_state.randint(img.shape[0])
+        p2 = np.array([[x2, y2]])
+
+        # Check the length of the line
+        line_length = np.sqrt(np.sum((p1 - p2) ** 2))
+        if line_length < min_len:
+            continue
+
+        # Check that there is no overlap
+        # (only the recorded segments are tested, not every drawn line)
+        if intersect(segments[:, 0:2], segments[:, 2:4], p1, p2, 2):
+            continue
+
+        col = get_random_color(background_color)
+        thickness = random_state.randint(min_dim * 0.01, min_dim * 0.02)
+        cv.line(img, (x1, y1), (x2, y2), col, thickness)
+
+        # Only record the segments longer than min_label_len
+        seg_len = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
+        if seg_len >= min_label_len:
+            segments = np.concatenate([segments,
+                                       np.array([[x1, y1, x2, y2]])], axis=0)
+            points = np.concatenate([points,
+                                     np.array([[x1, y1], [x2, y2]])], axis=0)
+
+    # If no line is drawn, recursively call the function
+    # (lines drawn but not recorded above stay on the image)
+    if points.shape[0] == 0:
+        return draw_lines(img, nb_lines, min_len, min_label_len)
+
+    # Get the line associativity map
+    line_map = get_line_map(points, segments)
+
+    return {
+        "points": points,
+        "line_map": line_map
+    }
+
+
+def check_segment_len(segments, min_len=32):
+    """ Tell whether at least one segment is shorter than min_len.
+    Parameters:
+        segments: array whose second axis stores the two end points of a
+                  segment: the first two entries for one end point, the
+                  remaining entries for the other one.
+        min_len: length below which a segment counts as too short.
+    Returns:
+        True if any segment is strictly shorter than min_len, else False.
+    """
+    endpoint_delta = segments[:, :2] - segments[:, 2:]
+    lengths = np.sqrt((endpoint_delta ** 2).sum(axis=1))
+    return bool((lengths < min_len).any())
+
+
+def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64):
+    """ Fill one random polygon on img and return its labeled junctions
+    together with their line map.
+    Parameters:
+        img: image to draw on (modified in place by cv.fillPoly).
+        max_sides: exclusive upper bound of the sampled corner number
+                   (between 3 and max_sides - 1 corners are sampled).
+        min_len: if any edge is shorter than this, a new polygon is
+                 sampled. A float <= 1. is a ratio of the smallest
+                 image side.
+        min_label_len: only edges at least this long are labeled.
+                       A float <= 1. is a ratio of the smallest image side.
+    Returns:
+        A dict with the unique end points of the labeled edges ("points")
+        and their connectivity ("line_map"). Both are None when no edge
+        is long enough to be labeled.
+    """
+    corner_count = random_state.randint(3, max_sides)
+    shortest_side = min(img.shape[0], img.shape[1])
+    radius = max(random_state.rand() * shortest_side / 2, shortest_side / 10)
+    # Center of the circle in which the corners are sampled
+    center_x = random_state.randint(radius, img.shape[1] - radius)
+    center_y = random_state.randint(radius, img.shape[0] - radius)
+
+    # Length constraints given as ratios are converted to pixels
+    if isinstance(min_len, float) and min_len <= 1.:
+        min_len = int(shortest_side * min_len)
+    if isinstance(min_label_len, float) and min_label_len <= 1.:
+        min_label_len = int(shortest_side * min_label_len)
+
+    # One random angle inside each of the corner_count angular sectors
+    bounds = np.linspace(0, 2 * math.pi, corner_count + 1)
+    angles = []
+    for lower, upper in zip(bounds[:-1], bounds[1:]):
+        angles.append(lower + random_state.rand() * (upper - lower))
+
+    # One corner per angle; the x scale is sampled before the y scale
+    vertices = []
+    for angle in angles:
+        scale_x = max(random_state.rand(), 0.4)
+        scale_y = max(random_state.rand(), 0.4)
+        vertices.append([int(center_x + scale_x * radius * math.cos(angle)),
+                         int(center_y + scale_y * radius * math.sin(angle))])
+    vertices = np.array(vertices)
+
+    # Drop the corners that coincide with their predecessor
+    gaps = [np.linalg.norm(vertices[(k - 1) % corner_count, :]
+                           - vertices[k, :])
+            for k in range(corner_count)]
+    vertices = vertices[np.array(gaps) > 0.01, :]
+    corner_count = vertices.shape[0]
+
+    # Drop the corners whose angle is too flat
+    openings = [
+        angle_between_vectors(
+            vertices[(k - 1) % corner_count, :] - vertices[k, :],
+            vertices[(k + 1) % corner_count, :] - vertices[k, :])
+        for k in range(corner_count)]
+    vertices = vertices[np.array(openings) < (2 * math.pi / 3), :]
+    corner_count = vertices.shape[0]
+
+    # Every edge joins a corner to the next one (the last wraps around).
+    # all_edges keeps every edge, labeled_edges only the long enough ones.
+    successors = np.roll(vertices, -1, axis=0)
+    all_edges = np.concatenate((vertices, successors),
+                               axis=1).astype(np.float64)
+    edge_lengths = np.sqrt(np.sum((vertices - successors) ** 2, axis=1))
+    labeled_edges = all_edges[edge_lengths >= min_label_len, :]
+
+    # Not enough corners or an edge that is too short: sample again
+    if corner_count < 3 or check_segment_len(all_edges, min_len):
+        return draw_polygon(img, max_sides, min_len, min_label_len)
+
+    # Junctions are the unique end points of the labeled edges
+    endpoints = np.concatenate((labeled_edges[:, :2], labeled_edges[:, 2:]),
+                               axis=0)
+    if endpoints.shape[0] == 0:
+        junc_points, line_map = None, None
+    else:
+        junc_points = np.unique(endpoints, axis=0)
+        line_map = get_line_map(junc_points, labeled_edges)
+
+    # The polygon is filled even when none of its edges is labeled
+    fill_color = get_random_color(int(np.mean(img)))
+    cv.fillPoly(img, [vertices.reshape((-1, 1, 2))], fill_color)
+
+    return {
+        "points": junc_points,
+        "line_map": line_map
+    }
+
+
+def overlap(center, rad, centers, rads):
+    """ Tell whether the circle (center, rad) overlaps one of the circles
+    described by centers and rads.
+    Returns:
+        True as soon as the distance between center and one of the other
+        centers is strictly smaller than the sum of the two radii,
+        False otherwise (also when rads is empty).
+    """
+    return any(np.linalg.norm(center - centers[k]) < rad + other_rad
+               for k, other_rad in enumerate(rads))
+
+
+def angle_between_vectors(v1, v2):
+    """ Return the angle (in rad) between the two vectors v1 and v2.
+    Both vectors are normalized first; their dot product is clipped to
+    [-1, 1] before the arc cosine so that rounding errors cannot push it
+    out of the domain of arccos.
+    """
+    cosine = np.dot(v1 / np.linalg.norm(v1), v2 / np.linalg.norm(v2))
+    return np.arccos(np.clip(cosine, -1.0, 1.0))
+
+
+def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
+                           min_label_len=64, safe_margin=5, **extra):
+    """ Draw multiple polygons with a random number of corners
+    and return the junction points + line map.
+    Parameters:
+        img: image to draw on (modified in place).
+        max_sides: maximal number of sides + 1
+        nb_polygons: maximal number of polygons
+        min_len: a polygon is rejected when one of its drawn edges is
+                 shorter than this. A float <= 1. is a ratio of the
+                 smallest image side.
+        min_label_len: only drawn edges at least this long are labeled.
+                       A float <= 1. is a ratio of the smallest image side.
+        safe_margin: the drawn polygon uses a radius reduced by this
+                     margin, while the collision checks use the full
+                     radius. A float <= 1. is a ratio of the smallest
+                     image side.
+        extra: keyword arguments forwarded to generate_custom_background.
+    Returns:
+        A dict with the unique end points of the labeled edges ("points")
+        and their connectivity ("line_map"). Both are None when no edge
+        has been labeled.
+    """
+    # Outer segments of the accepted polygons (collision checks only)
+    segments = np.empty((0, 4), dtype=np.int32)
+    # Inner (drawn) segments that are long enough to be labeled
+    label_segments = np.empty((0, 4), dtype=np.int32)
+    # Circle centers and radii of the accepted polygons
+    centers = []
+    rads = []
+    # NOTE(review): points is only ever appended to, never read afterwards
+    points = np.empty((0, 2), dtype=np.int32)
+    background_color = int(np.mean(img))
+
+    min_dim = min(img.shape[0], img.shape[1])
+    # Convert length constrain to pixel if given float number
+    if isinstance(min_len, float) and min_len <= 1.:
+        min_len = int(min_dim * min_len)
+    if isinstance(min_label_len, float) and min_label_len <= 1.:
+        min_label_len = int(min_dim * min_label_len)
+    if isinstance(safe_margin, float) and safe_margin <= 1.:
+        safe_margin = int(min_dim * safe_margin)
+
+    # Sequentially generate polygons
+    for i in range(nb_polygons):
+        num_corners = random_state.randint(3, max_sides)
+        # Same value as the min_dim computed before the loop
+        min_dim = min(img.shape[0], img.shape[1])
+
+        # Also add the real radius
+        # (rad bounds the polygon for the collision checks, rad_real is
+        # the smaller radius of the polygon that is actually drawn)
+        rad = max(random_state.rand() * min_dim / 2, min_dim / 9)
+        rad_real = rad - safe_margin
+
+        # Center of a circle
+        x = random_state.randint(rad, img.shape[1] - rad)
+        y = random_state.randint(rad, img.shape[0] - rad)
+
+        # Sample num_corners points inside the circle
+        # (one random angle inside each of the num_corners sectors)
+        slices = np.linspace(0, 2 * math.pi, num_corners + 1)
+        angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i])
+                  for i in range(num_corners)]
+
+        # Sample outer points and inner points
+        # (both share the same random offsets and only differ by the radius)
+        new_points = []
+        new_points_real = []
+        for a in angles:
+            x_offset = max(random_state.rand(), 0.4)
+            y_offset = max(random_state.rand(), 0.4)
+            new_points.append([int(x + x_offset * rad * math.cos(a)),
+                               int(y + y_offset * rad * math.sin(a))])
+            new_points_real.append(
+                [int(x + x_offset * rad_real * math.cos(a)),
+                 int(y + y_offset * rad_real * math.sin(a))])
+        new_points = np.array(new_points)
+        new_points_real = np.array(new_points_real)
+
+        # Filter the points that are too close or that have an angle too flat
+        # (both masks are computed on the outer points and applied to the
+        # outer and inner points alike)
+        norms = [np.linalg.norm(new_points[(i-1) % num_corners, :]
+                                - new_points[i, :])
+                 for i in range(num_corners)]
+        mask = np.array(norms) > 0.01
+        new_points = new_points[mask, :]
+        new_points_real = new_points_real[mask, :]
+
+        num_corners = new_points.shape[0]
+        corner_angles = [
+            angle_between_vectors(new_points[(i-1) % num_corners, :] -
+                                  new_points[i, :],
+                                  new_points[(i+1) % num_corners, :] -
+                                  new_points[i, :])
+            for i in range(num_corners)]
+        mask = np.array(corner_angles) < (2 * math.pi / 3)
+        new_points = new_points[mask, :]
+        new_points_real = new_points_real[mask, :]
+        num_corners = new_points.shape[0]
+
+        # Not enough corners
+        if num_corners < 3:
+            continue
+
+        # Segments for checking overlap (outer circle)
+        # Layout (1, 4, num_corners): along the second axis x1, y1, x2, y2,
+        # one edge per entry of the last axis; edge i joins corner i to
+        # corner i + 1 (wrapping around).
+        new_segments = np.zeros((1, 4, num_corners))
+        new_segments[:, 0, :] = [new_points[i][0] for i in range(num_corners)]
+        new_segments[:, 1, :] = [new_points[i][1] for i in range(num_corners)]
+        new_segments[:, 2, :] = [new_points[(i+1) % num_corners][0]
+                                 for i in range(num_corners)]
+        new_segments[:, 3, :] = [new_points[(i+1) % num_corners][1]
+                                 for i in range(num_corners)]
+
+        # Segments to record (inner circle)
+        # (same layout as new_segments)
+        new_segments_real = np.zeros((1, 4, num_corners))
+        new_segments_real[:, 0, :] = [new_points_real[i][0]
+                                      for i in range(num_corners)]
+        new_segments_real[:, 1, :] = [new_points_real[i][1]
+                                      for i in range(num_corners)]
+        new_segments_real[:, 2, :] = [
+            new_points_real[(i + 1) % num_corners][0]
+            for i in range(num_corners)]
+        new_segments_real[:, 3, :] = [
+            new_points_real[(i + 1) % num_corners][1]
+            for i in range(num_corners)]
+
+        # Check that the polygon will not overlap with pre-existing shapes
+        # (intersect is defined elsewhere; presumably it tests the stored
+        # outer segments against the new outer edges - TODO confirm)
+        if intersect(segments[:, 0:2, None], segments[:, 2:4, None],
+                     new_segments[:, 0:2, :], new_segments[:, 2:4, :],
+                     3) or overlap(np.array([x, y]), rad, centers, rads):
+            continue
+
+        # Check that the the edges of the polygon is not too short
+        # (the slicing and the axis=1 sum of check_segment_len also work
+        # on the (1, 4, num_corners) layout)
+        if check_segment_len(new_segments_real, min_len):
+            continue
+
+        # If the polygon is valid, append it to the polygon set
+        centers.append(np.array([x, y]))
+        rads.append(rad)
+        # (1, 4, num_corners) -> (num_corners, 4): one segment per row
+        new_segments = np.reshape(np.swapaxes(new_segments, 0, 2), (-1, 4))
+        segments = np.concatenate([segments, new_segments], axis=0)
+
+        # Only record the segments longer than min_label_len
+        new_segments_real = np.reshape(np.swapaxes(new_segments_real, 0, 2),
+                                       (-1, 4))
+        points1 = new_segments_real[:, :2]
+        points2 = new_segments_real[:, 2:]
+        seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
+        new_label_segment = new_segments_real[seg_len >= min_label_len, :]
+        label_segments = np.concatenate([label_segments, new_label_segment],
+                                        axis=0)
+
+        # Color the polygon with a custom background
+        # (the inner polygon is rasterized into a mask, then the masked
+        # pixels of img are replaced by those of a freshly generated
+        # background)
+        corners = new_points_real.reshape((-1, 1, 2))
+        mask = np.zeros(img.shape, np.uint8)
+        custom_background = generate_custom_background(
+            img.shape, background_color, **extra)
+
+        cv.fillPoly(mask, [corners], 255)
+        locs = np.where(mask != 0)
+        img[locs[0], locs[1]] = custom_background[locs[0], locs[1]]
+        points = np.concatenate([points, new_points], axis=0)
+
+    # Get all junctions from label segments
+    junctions_all = np.concatenate(
+        (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+    if junctions_all.shape[0] == 0:
+        junc_points = None
+        line_map = None
+
+    else:
+        # Shared end points are merged into a single junction
+        junc_points = np.unique(junctions_all, axis=0)
+
+        # Generate line map from points and segments
+        line_map = get_line_map(junc_points, label_segments)
+
+    return {
+        "points": junc_points,
+        "line_map": line_map
+    }
+
+
+def draw_ellipses(img, nb_ellipses=20):
+    """ Draw several ellipses.
+    Parameters:
+        img: image to draw on (modified in place by cv.ellipse).
+        nb_ellipses: maximal number of ellipses
+    Returns:
+        An empty [0, 2] int32 array: no point is reported for ellipses.
+    """
+    centers = np.empty((0, 2), dtype=np.int32)
+    # Radii are kept 1-D so that each radius is paired with the distance
+    # to its own center. A [k, 1] array would broadcast against the [k]
+    # distances into a k x k matrix mixing every radius with every center.
+    rads = np.empty((0,), dtype=np.int32)
+    # Upper bound of the sampled semi-axes: a quarter of the smallest side
+    max_axis = min(img.shape[0], img.shape[1]) / 4
+    background_color = int(np.mean(img))
+    for _ in range(nb_ellipses):
+        ax = int(max(random_state.rand() * max_axis, max_axis / 5))
+        ay = int(max(random_state.rand() * max_axis, max_axis / 5))
+        # Radius of the circle bounding the ellipse
+        max_rad = max(ax, ay)
+        x = random_state.randint(max_rad, img.shape[1] - max_rad)  # center
+        y = random_state.randint(max_rad, img.shape[0] - max_rad)
+        new_center = np.array([[x, y]])
+
+        # Check that the ellipsis will not overlap with pre-existing shapes:
+        # reject it if its bounding circle meets the bounding circle of an
+        # ellipse that has already been drawn.
+        diff = centers - new_center
+        dists = np.sqrt(np.sum(diff * diff, axis=1))
+        if np.any(max_rad > dists - rads):
+            continue
+        centers = np.concatenate([centers, new_center], axis=0)
+        rads = np.concatenate([rads, np.array([max_rad])], axis=0)
+
+        col = get_random_color(background_color)
+        angle = random_state.rand() * 90
+        cv.ellipse(img, (x, y), (ax, ay), angle, 0, 360, col, -1)
+    return np.empty((0, 2), dtype=np.int32)
+
+
+def draw_star(img, nb_branches=6, min_len=32, min_label_len=64):
+    """ Draw a star (branches sharing one center) and return the junction
+    points + line map.
+    Parameters:
+        img: image to draw on (modified in place by cv.line).
+        nb_branches: exclusive upper bound of the sampled branch number
+                     (between 3 and nb_branches - 1 branches are drawn).
+        min_len: if any branch is shorter than this, a new star is
+                 sampled. A float <= 1. is a ratio of the smallest
+                 image side.
+        min_label_len: only branches at least this long are labeled.
+                       A float <= 1. is a ratio of the smallest image side.
+    Returns:
+        A dict with the unique end points of the labeled branches
+        ("points") and their connectivity ("line_map"). Both are None
+        when no branch is long enough to be labeled.
+    """
+    branch_count = random_state.randint(3, nb_branches)
+    shortest_side = min(img.shape[0], img.shape[1])
+    # Length constraints given as ratios are converted to pixels
+    if isinstance(min_len, float) and min_len <= 1.:
+        min_len = int(shortest_side * min_len)
+    if isinstance(min_label_len, float) and min_label_len <= 1.:
+        min_label_len = int(shortest_side * min_label_len)
+
+    thickness = random_state.randint(shortest_side * 0.01,
+                                     shortest_side * 0.025)
+    radius = max(random_state.rand() * shortest_side / 2, shortest_side / 5)
+    x = random_state.randint(radius, img.shape[1] - radius)
+    y = random_state.randint(radius, img.shape[0] - radius)
+
+    # One random angle inside each of the branch_count angular sectors
+    bounds = np.linspace(0, 2 * math.pi, branch_count + 1)
+    angles = []
+    for lower, upper in zip(bounds[:-1], bounds[1:]):
+        angles.append(lower + random_state.rand() * (upper - lower))
+
+    # One branch tip per angle; the x scale is sampled before the y scale
+    tips = []
+    for angle in angles:
+        scale_x = max(random_state.rand(), 0.3)
+        scale_y = max(random_state.rand(), 0.3)
+        tips.append([int(x + scale_x * radius * math.cos(angle)),
+                     int(y + scale_y * radius * math.sin(angle))])
+    # Row 0 is the center of the star, the other rows are the tips
+    points = np.concatenate(([[x, y]], np.array(tips)), axis=0)
+
+    # Every branch joins the center to one tip; sample again if one of
+    # them is too short
+    segments = np.array([[x, y, tip_x, tip_y] for tip_x, tip_y in points[1:]])
+    if check_segment_len(segments, min_len):
+        return draw_star(img, nb_branches, min_len, min_label_len)
+
+    # Only the branches that are long enough are labeled
+    branch_len = np.sqrt(
+        np.sum((segments[:, :2] - segments[:, 2:]) ** 2, axis=1))
+    label_segments = segments[branch_len >= min_label_len, :]
+
+    # Junctions are the unique end points of the labeled branches
+    endpoints = np.concatenate(
+        (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+    if endpoints.shape[0] == 0:
+        junc_points, line_map = None, None
+    else:
+        junc_points = np.unique(endpoints, axis=0)
+        line_map = get_line_map(junc_points, label_segments)
+
+    # Every branch is drawn, each with its own random color
+    background_color = int(np.mean(img))
+    hub = (points[0][0], points[0][1])
+    for tip in points[1:]:
+        col = get_random_color(background_color)
+        cv.line(img, hub, (tip[0], tip[1]), col, thickness)
+
+    return {
+        "points": junc_points,
+        "line_map": line_map
+    }
+
+
+def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
+                               transform_params=(0.05, 0.15),
+                               min_label_len=64, seed=None):
+    """ Draw a checkerboard and output the junctions + line segments
+    Parameters:
+        img: image to draw on (modified in place). The warp set-up uses
+             img.shape as a 2-vector, so a 2D image is expected.
+        max_rows: maximal number of rows + 1
+        max_cols: maximal number of cols + 1
+        transform_params: set the range of the parameters of the transformations
+        min_label_len: only segments at least this long are labeled.
+                       A float <= 1. is a ratio of the smallest image side.
+        seed: optional seed (mostly for debugging). NOTE: when a seed is
+              given, the module-level random_state is REPLACED by a new
+              RandomState(seed), so later calls relying on it are
+              affected as well.
+    Returns:
+        A dict with the junctions ("points") and the line map ("line_map")
+        computed by get_unique_junctions from the labeled segments.
+    """
+    # A `global` statement applies to the whole function body, whatever
+    # the branch it is written in: a given seed rebinds the module-level
+    # generator.
+    global random_state
+    if seed is not None:
+        random_state = np.random.RandomState(seed)
+
+    background_color = int(np.mean(img))
+
+    min_dim = min(img.shape)
+    if isinstance(min_label_len, float) and min_label_len <= 1.:
+        min_label_len = int(min_dim * min_label_len)
+    # Create the grid
+    rows = random_state.randint(3, max_rows)  # number of rows
+    cols = random_state.randint(3, max_cols)  # number of cols
+    s = min((img.shape[1] - 1) // cols, (img.shape[0] - 1) // rows)
+    x_coord = np.tile(range(cols + 1),
+                      rows + 1).reshape(((rows + 1) * (cols + 1), 1))
+    y_coord = np.repeat(range(rows + 1),
+                        cols + 1).reshape(((rows + 1) * (cols + 1), 1))
+
+    # points are the grid coordinates
+    # (row-major: grid corner (i, j) is stored at index i * (cols + 1) + j)
+    points = s * np.concatenate([x_coord, y_coord], axis=1)
+
+    # Warp the grid using an affine transformation and an homography
+    alpha_affine = np.max(img.shape) * (
+        transform_params[0] + random_state.rand() * transform_params[1])
+    center_square = np.float32(img.shape) // 2
+    square_size = min_dim // 3
+    pts1 = np.float32([center_square + square_size,
+                       [center_square[0] + square_size,
+                        center_square[1] - square_size],
+                       center_square - square_size,
+                       [center_square[0] - square_size,
+                        center_square[1] + square_size]])
+    pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine,
+                                       size=pts1.shape).astype(np.float32)
+    affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3])
+    pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2,
+                                       size=pts1.shape).astype(np.float32)
+    perspective_transform = cv.getPerspectiveTransform(pts1, pts2)
+
+    # Apply the affine transformation
+    points = np.transpose(np.concatenate(
+        (points, np.ones(((rows + 1) * (cols + 1), 1))), axis=1))
+    warped_points = np.transpose(np.dot(affine_transform, points))
+
+    # Apply the homography
+    # (projective transform written out: x and y are divided by the
+    # third homogeneous coordinate)
+    warped_col0 = np.add(np.sum(np.multiply(
+        warped_points, perspective_transform[0, :2]), axis=1),
+        perspective_transform[0, 2])
+    warped_col1 = np.add(np.sum(np.multiply(
+        warped_points, perspective_transform[1, :2]), axis=1),
+        perspective_transform[1, 2])
+    warped_col2 = np.add(np.sum(np.multiply(
+        warped_points, perspective_transform[2, :2]), axis=1),
+        perspective_transform[2, 2])
+    warped_col0 = np.divide(warped_col0, warped_col2)
+    warped_col1 = np.divide(warped_col1, warped_col2)
+    warped_points = np.concatenate(
+        [warped_col0[:, None], warped_col1[:, None]], axis=1)
+    warped_points = warped_points.astype(int)
+
+    # Fill the rectangles
+    colors = np.zeros((rows * cols,), np.int32)
+
+    label_segments = []
+    for i in range(rows):
+        for j in range(cols):
+            # Get a color that contrast with the neighboring cells
+            if i == 0 and j == 0:
+                col = get_random_color(background_color)
+            else:
+                neighboring_colors = []
+                if i != 0:
+                    neighboring_colors.append(colors[(i - 1) * cols + j])
+                if j != 0:
+                    neighboring_colors.append(colors[i * cols + j - 1])
+                col = get_different_color(np.array(neighboring_colors))
+            colors[i * cols + j] = col
+
+            # Warped corners of the cell, walking around it: grid corners
+            # (i, j), (i, j + 1), (i + 1, j + 1) and (i + 1, j)
+            first = i * (cols + 1) + j
+            below = (i + 1) * (cols + 1) + j
+            cell = warped_points[[first, first + 1, below + 1, below]]
+
+            # Fill the cell
+            cv.fillConvexPoly(img, cell, col)
+
+            # The four edges of the cell: every corner joined to the next
+            # one, the last corner being joined to the first
+            label_segments.append(
+                np.concatenate((cell, np.roll(cell, -1, axis=0)), axis=1))
+
+    label_segments = np.concatenate(label_segments)
+
+    # Keep the segments lying inside the image and clip those crossing
+    # its border. The leading float array keeps the result of shape
+    # [0, 4] when nothing is kept.
+    kept_segments = [np.zeros([0, 4])]
+    # Define image boundary polygon (in x y manner)
+    image_poly = shapely.geometry.Polygon(
+        [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1],
+         [0, img.shape[0] - 1]])
+    for seg_raw in label_segments:
+        # Get the line segment
+        seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]])
+        clipped = seg.intersection(image_poly)
+
+        # The line segment is just inside the image.
+        if clipped == seg:
+            kept_segments.append(seg_raw[None, ...])
+
+        # Intersect with the image.
+        elif seg.intersects(image_poly):
+            # Keep the clipped part. This fails (and the segment is
+            # dropped) when the intersection is not a two-point line,
+            # e.g. when the segment touches the image at exactly one
+            # point. Unlike a bare except, KeyboardInterrupt and
+            # SystemExit are not swallowed.
+            try:
+                clipped_segment = np.array(clipped.coords).reshape([-1, 4])
+            except Exception:
+                continue
+            kept_segments.append(clipped_segment)
+
+    label_segments = np.round(
+        np.concatenate(kept_segments, axis=0)).astype(np.int32)
+
+    # Only record the segments longer than min_label_len
+    points1 = label_segments[:, :2]
+    points2 = label_segments[:, 2:]
+    seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
+    label_segments = label_segments[seg_len >= min_label_len, :]
+
+    # Get all junctions from label segments
+    junc_points, line_map = get_unique_junctions(label_segments,
+                                                 min_label_len)
+
+    # Draw lines on the boundaries of the board at random
+    nb_rows = random_state.randint(2, rows + 2)
+    nb_cols = random_state.randint(2, cols + 2)
+    # randint truncates float bounds to integers. The bounds are made
+    # explicit and clamped so that the range is never empty and the
+    # thickness is at least 1 (e.g. a smallest side of 128 would otherwise
+    # lead to randint(1, 1)). Sizes that already worked are unaffected.
+    min_thickness = max(int(min_dim * 0.01), 1)
+    max_thickness = max(int(min_dim * 0.015), min_thickness + 1)
+    thickness = random_state.randint(min_thickness, max_thickness)
+    for _ in range(nb_rows):
+        # Line between two random corners of the same grid row
+        row_idx = random_state.randint(rows + 1)
+        col_idx1 = random_state.randint(cols + 1)
+        col_idx2 = random_state.randint(cols + 1)
+        col = get_random_color(background_color)
+        start = warped_points[row_idx * (cols + 1) + col_idx1]
+        end = warped_points[row_idx * (cols + 1) + col_idx2]
+        cv.line(img, (start[0], start[1]), (end[0], end[1]), col, thickness)
+    for _ in range(nb_cols):
+        # Line between two random corners of the same grid column
+        col_idx = random_state.randint(cols + 1)
+        row_idx1 = random_state.randint(rows + 1)
+        row_idx2 = random_state.randint(rows + 1)
+        col = get_random_color(background_color)
+        start = warped_points[row_idx1 * (cols + 1) + col_idx]
+        end = warped_points[row_idx2 * (cols + 1) + col_idx]
+        cv.line(img, (start[0], start[1]), (end[0], end[1]), col, thickness)
+
+    # Keep only the points inside the image
+    # NOTE(review): the result is not used by this function; the call is
+    # kept as it was.
+    points = keep_points_inside(warped_points, img.shape[:2])
+    return {
+        "points": junc_points,
+        "line_map": line_map
+    }
+
+
+def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
+                          transform_params=(0.05, 0.15), seed=None):
+    """ Draw stripes in a distorted rectangle
+    and output the junctions points + line map.
+    Parameters:
+        img: image to draw on (modified in place). The warp set-up uses
+             img.shape as a 2-vector, so a 2D image is expected.
+        max_nb_cols: exclusive upper bound of the number of sampled
+                     stripe boundaries
+        min_len: minimal width of a stripe; boundaries closer than this
+                 to the next one are removed. A float <= 1. is a ratio of
+                 the smallest image side.
+        min_label_len: only segments at least this long are labeled.
+                       A float <= 1. is a ratio of the smallest image side.
+        transform_params: set the range of the parameters of the transformations
+        seed: optional seed (mostly for debugging). NOTE: when a seed is
+              given, the module-level random_state is REPLACED by a new
+              RandomState(seed), so later calls relying on it are
+              affected as well.
+    Returns:
+        A dict with the unique end points of the labeled segments
+        ("points") and their connectivity ("line_map"). Both are None
+        when no segment is long enough to be labeled.
+    """
+    # A `global` statement applies to the whole function body, whatever
+    # the branch it is written in: a given seed rebinds the module-level
+    # generator.
+    global random_state
+    if seed is not None:
+        random_state = np.random.RandomState(seed)
+
+    background_color = int(np.mean(img))
+    # Create the grid
+    board_size = (int(img.shape[0] * (1 + random_state.rand())),
+                  int(img.shape[1] * (1 + random_state.rand())))
+
+    # Number of cols
+    col = random_state.randint(5, max_nb_cols)
+    cols = np.concatenate([board_size[1] * random_state.rand(col - 1),
+                           np.array([0, board_size[1] - 1])], axis=0)
+    cols = np.unique(cols.astype(int))
+
+    # Remove the indices that are too close
+    min_dim = min(img.shape)
+
+    # Convert length constrain to pixel if given float number
+    if isinstance(min_len, float) and min_len <= 1.:
+        min_len = int(min_dim * min_len)
+    if isinstance(min_label_len, float) and min_label_len <= 1.:
+        min_label_len = int(min_dim * min_label_len)
+
+    # A boundary is kept when the next one is at least min_len away; the
+    # appended sentinel makes sure that the last boundary is always kept.
+    cols = cols[(np.concatenate([cols[1:],
+                                 np.array([board_size[1] + min_len])],
+                                axis=0) - cols) >= min_len]
+    # Update the number of cols
+    col = cols.shape[0] - 1
+    cols = np.reshape(cols, (col + 1, 1))
+    # Top end points (y = 0) then bottom end points (y = board height - 1)
+    # of the col + 1 boundaries
+    cols1 = np.concatenate([cols, np.zeros((col + 1, 1), np.int32)], axis=1)
+    cols2 = np.concatenate(
+        [cols, (board_size[0] - 1) * np.ones((col + 1, 1), np.int32)], axis=1)
+    points = np.concatenate([cols1, cols2], axis=0)
+
+    # Warp the grid using an affine transformation and a homography
+    alpha_affine = np.max(img.shape) * (
+        transform_params[0] + random_state.rand() * transform_params[1])
+    center_square = np.float32(img.shape) // 2
+    square_size = min(img.shape) // 3
+    pts1 = np.float32([center_square + square_size,
+                       [center_square[0]+square_size,
+                        center_square[1]-square_size],
+                       center_square - square_size,
+                       [center_square[0]-square_size,
+                        center_square[1]+square_size]])
+    pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine,
+                                       size=pts1.shape).astype(np.float32)
+    affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3])
+    pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2,
+                                       size=pts1.shape).astype(np.float32)
+    perspective_transform = cv.getPerspectiveTransform(pts1, pts2)
+
+    # Apply the affine transformation
+    points = np.transpose(np.concatenate((points,
+                                          np.ones((2 * (col + 1), 1))),
+                                         axis=1))
+    warped_points = np.transpose(np.dot(affine_transform, points))
+
+    # Apply the homography
+    # (projective transform written out: x and y are divided by the
+    # third homogeneous coordinate)
+    warped_col0 = np.add(np.sum(np.multiply(
+        warped_points, perspective_transform[0, :2]), axis=1),
+        perspective_transform[0, 2])
+    warped_col1 = np.add(np.sum(np.multiply(
+        warped_points, perspective_transform[1, :2]), axis=1),
+        perspective_transform[1, 2])
+    warped_col2 = np.add(np.sum(np.multiply(
+        warped_points, perspective_transform[2, :2]), axis=1),
+        perspective_transform[2, 2])
+    warped_col0 = np.divide(warped_col0, warped_col2)
+    warped_col1 = np.divide(warped_col1, warped_col2)
+    warped_points = np.concatenate(
+        [warped_col0[:, None], warped_col1[:, None]], axis=1)
+    # The float coordinates are kept for the labels, the integer ones
+    # are used for drawing
+    warped_points_float = warped_points.copy()
+    warped_points = warped_points.astype(int)
+
+    # Fill the rectangles and get the segments
+    color = get_random_color(background_color)
+    for i in range(col):
+        # Fill the color
+        color = (color + 128 + random_state.randint(-30, 30)) % 256
+        cv.fillConvexPoly(img, np.array([(warped_points[i, 0],
+                                          warped_points[i, 1]),
+                                         (warped_points[i+1, 0],
+                                          warped_points[i+1, 1]),
+                                         (warped_points[i+col+2, 0],
+                                          warped_points[i+col+2, 1]),
+                                         (warped_points[i+col+1, 0],
+                                          warped_points[i+col+1, 1])]),
+                          color)
+
+    segments = np.zeros([0, 4])
+    row = 1  # in stripes case
+    # Iterate through rows
+    for row_idx in range(row + 1):
+        # Include all the combination of the junctions
+        # Iterate through all the combination of junction index in that row
+        multi_seg_lst = [np.array(
+            [warped_points_float[id1, 0],
+             warped_points_float[id1, 1],
+             warped_points_float[id2, 0],
+             warped_points_float[id2, 1]])[None, ...]
+            for (id1, id2) in combinations(range(
+                row_idx * (col + 1), (row_idx + 1) * (col + 1), 1), 2)]
+        multi_seg = np.concatenate(multi_seg_lst, axis=0)
+        segments = np.concatenate((segments, multi_seg), axis=0)
+
+    # Iterate through columns
+    for col_idx in range(col + 1):  # for 5 columns, we will have 5 + 1 edges.
+        # Include all the combination of the junctions
+        # Iterate through all the combination of junction index in that column
+        multi_seg_lst = [np.array(
+            [warped_points_float[id1, 0],
+             warped_points_float[id1, 1],
+             warped_points_float[id2, 0],
+             warped_points_float[id2, 1]])[None, ...]
+            for (id1, id2) in combinations(range(
+                col_idx, col_idx + (row * col) + 2, col + 1), 2)]
+        multi_seg = np.concatenate(multi_seg_lst, axis=0)
+        segments = np.concatenate((segments, multi_seg), axis=0)
+
+    # Select and refine the segments: keep those lying inside the image
+    # and clip those crossing its border. The leading float array keeps
+    # the result of shape [0, 4] when nothing is kept.
+    kept_segments = [np.zeros([0, 4])]
+    # Define image boundary polygon (in x y manner)
+    image_poly = shapely.geometry.Polygon(
+        [[0, 0], [img.shape[1]-1, 0], [img.shape[1]-1, img.shape[0]-1],
+         [0, img.shape[0]-1]])
+    for seg_raw in segments:
+        # Get the line segment
+        seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]])
+        clipped = seg.intersection(image_poly)
+
+        # The line segment is just inside the image.
+        if clipped == seg:
+            kept_segments.append(seg_raw[None, ...])
+
+        # Intersect with the image.
+        elif seg.intersects(image_poly):
+            # Keep the clipped part. This fails (and the segment is
+            # dropped) when the intersection is not a two-point line,
+            # e.g. when the segment touches the image at exactly one
+            # point. Unlike a bare except, KeyboardInterrupt and
+            # SystemExit are not swallowed.
+            try:
+                clipped_segment = np.array(clipped.coords).reshape([-1, 4])
+            except Exception:
+                continue
+            kept_segments.append(clipped_segment)
+
+    segments = np.round(np.concatenate(kept_segments, axis=0)).astype(np.int32)
+
+    # Only record the segments longer than min_label_len
+    points1 = segments[:, :2]
+    points2 = segments[:, 2:]
+    seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
+    label_segments = segments[seg_len >= min_label_len, :]
+
+    # Get all junctions from label segments
+    junctions_all = np.concatenate(
+        (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+    if junctions_all.shape[0] == 0:
+        junc_points = None
+        line_map = None
+
+    # Get all unique junction points
+    else:
+        junc_points = np.unique(junctions_all, axis=0)
+        # Generate line map from points and segments
+        line_map = get_line_map(junc_points, label_segments)
+
+    # Draw lines on the boundaries of the stripes at random
+    nb_rows = random_state.randint(2, 5)
+    nb_cols = random_state.randint(2, col + 2)
+    # randint truncates float bounds to integers. The bounds are made
+    # explicit and clamped so that the range is never empty and the
+    # thickness is at least 1 (e.g. a smallest side of 256 or 512 would
+    # otherwise lead to randint(2, 2) or randint(5, 5)). Sizes that
+    # already worked are unaffected.
+    min_thickness = max(int(min_dim * 0.01), 1)
+    max_thickness = max(int(min_dim * 0.011), min_thickness + 1)
+    thickness = random_state.randint(min_thickness, max_thickness)
+    for _ in range(nb_rows):
+        # Line between two random points of the top (offset 0) or of the
+        # bottom (offset col + 1) border
+        row_idx = random_state.choice([0, col + 1])
+        col_idx1 = random_state.randint(col + 1)
+        col_idx2 = random_state.randint(col + 1)
+        color = get_random_color(background_color)
+        cv.line(img, (warped_points[row_idx + col_idx1, 0],
+                      warped_points[row_idx + col_idx1, 1]),
+                (warped_points[row_idx + col_idx2, 0],
+                 warped_points[row_idx + col_idx2, 1]),
+                color, thickness)
+
+    for _ in range(nb_cols):
+        # Line along a random stripe boundary, from top to bottom
+        col_idx = random_state.randint(col + 1)
+        color = get_random_color(background_color)
+        cv.line(img, (warped_points[col_idx, 0],
+                      warped_points[col_idx, 1]),
+                (warped_points[col_idx + col + 1, 0],
+                 warped_points[col_idx + col + 1, 1]),
+                color, thickness)
+
+    return {
+        "points": junc_points,
+        "line_map": line_map
+    }
+
+
def draw_cube(img, min_size_ratio=0.2, min_label_len=64,
              scale_interval=(0.4, 0.6), trans_interval=(0.5, 0.2)):
    """ Draw a 2D projection of a cube and output the visible junctions.

    The cube is drawn directly onto `img` with cv.fillPoly / cv.line.

    Parameters:
        img: image array to draw on; its mean is used as background color.
        min_size_ratio: min(img.shape) * min_size_ratio is the smallest
                        achievable cube side size
        min_label_len: clipped segments shorter than this length are
                       drawn but not returned as labels
        scale_interval: the scale is between scale_interval[0] and
                        scale_interval[0]+scale_interval[1]
        trans_interval: the translation is between img.shape*trans_interval[0]
                        and img.shape*(trans_interval[0] + trans_interval[1])
    Returns:
        A dict with keys "points" (unique end points of the labelled
        segments) and "line_map" (result of get_line_map); both are None
        when no segment reaches min_label_len.
    """
    # Generate a cube and apply to it an affine transformation
    # The order matters!
    # The indices of two adjacent vertices differ only by one bit (Gray code)
    background_color = int(np.mean(img))
    min_dim = min(img.shape[:2])
    min_side = min_dim * min_size_ratio
    lx = min_side + random_state.rand() * 2 * min_dim / 3  # dims of the cube
    ly = min_side + random_state.rand() * 2 * min_dim / 3
    lz = min_side + random_state.rand() * 2 * min_dim / 3
    cube = np.array([[0, 0, 0],
                     [lx, 0, 0],
                     [0, ly, 0],
                     [lx, ly, 0],
                     [0, 0, lz],
                     [lx, 0, lz],
                     [0, ly, lz],
                     [lx, ly, lz]])
    rot_angles = random_state.rand(3) * 3 * math.pi / 10. + math.pi / 10.
    rotation_1 = np.array([[math.cos(rot_angles[0]),
                            -math.sin(rot_angles[0]), 0],
                           [math.sin(rot_angles[0]),
                            math.cos(rot_angles[0]), 0],
                           [0, 0, 1]])
    rotation_2 = np.array([[1, 0, 0],
                           [0, math.cos(rot_angles[1]),
                            -math.sin(rot_angles[1])],
                           [0, math.sin(rot_angles[1]),
                            math.cos(rot_angles[1])]])
    rotation_3 = np.array([[math.cos(rot_angles[2]), 0,
                            -math.sin(rot_angles[2])],
                           [0, 1, 0],
                           [math.sin(rot_angles[2]), 0,
                            math.cos(rot_angles[2])]])
    scaling = np.array([[scale_interval[0] +
                         random_state.rand() * scale_interval[1], 0, 0],
                        [0, scale_interval[0] +
                         random_state.rand() * scale_interval[1], 0],
                        [0, 0, scale_interval[0] +
                         random_state.rand() * scale_interval[1]]])
    trans = np.array([img.shape[1] * trans_interval[0] +
                      random_state.randint(-img.shape[1] * trans_interval[1],
                                           img.shape[1] * trans_interval[1]),
                      img.shape[0] * trans_interval[0] +
                      random_state.randint(-img.shape[0] * trans_interval[1],
                                           img.shape[0] * trans_interval[1]),
                      0])
    cube = trans + np.transpose(
        np.dot(scaling, np.dot(rotation_1,
               np.dot(rotation_2, np.dot(rotation_3, np.transpose(cube))))))

    # The hidden corner is 0 by construction
    # The front one is 7
    cube = cube[:, :2]  # project on the plane z=0
    cube = cube.astype(int)
    points = cube[1:, :]  # get rid of the hidden corner

    # Get the three visible faces
    faces = np.array([[7, 3, 1, 5], [7, 5, 4, 6], [7, 6, 2, 3]])

    # Get all visible line segments
    segments = np.zeros([0, 4])
    # Iterate through all the faces
    for face_idx in range(faces.shape[0]):
        face = faces[face_idx, :]
        # Expand the four edges of the face by brute force
        segment = np.array(
            [np.concatenate((cube[face[0]], cube[face[1]]), axis=0),
             np.concatenate((cube[face[1]], cube[face[2]]), axis=0),
             np.concatenate((cube[face[2]], cube[face[3]]), axis=0),
             np.concatenate((cube[face[3]], cube[face[0]]), axis=0)])
        segments = np.concatenate((segments, segment), axis=0)

    # Select and refine the segments
    segments_new = np.zeros([0, 4])
    # Define image boundary polygon (in x y manner)
    image_poly = shapely.geometry.Polygon(
        [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1],
         [0, img.shape[0] - 1]])
    for idx in range(segments.shape[0]):
        # Get the line segment
        seg_raw = segments[idx, :]
        seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]])

        # The line segment is just inside the image.
        if seg.intersection(image_poly) == seg:
            segments_new = np.concatenate(
                (segments_new, seg_raw[None, ...]), axis=0)

        # Intersect with the image.
        elif seg.intersects(image_poly):
            try:
                p = np.array(
                    seg.intersection(image_poly).coords).reshape([-1, 4])
            # Best effort: skip intersections that cannot be expressed as
            # a segment, without swallowing KeyboardInterrupt/SystemExit.
            except Exception:
                continue
            segment = p
            segments_new = np.concatenate((segments_new, segment), axis=0)

        else:
            continue

    segments = (np.round(segments_new)).astype(np.int32)

    # Only record the segments longer than min_label_len
    points1 = segments[:, :2]
    points2 = segments[:, 2:]
    seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
    label_segments = segments[seg_len >= min_label_len, :]

    # Get all junctions from label segments
    junctions_all = np.concatenate(
        (label_segments[:, :2], label_segments[:, 2:]), axis=0)
    if junctions_all.shape[0] == 0:
        junc_points = None
        line_map = None

    # Get all unique junction points
    else:
        junc_points = np.unique(junctions_all, axis=0)
        # Generate line map from points and segments
        line_map = get_line_map(junc_points, label_segments)

    # Fill the faces and draw the contours
    col_face = get_random_color(background_color)
    for i in [0, 1, 2]:
        cv.fillPoly(img, [cube[faces[i]].reshape((-1, 1, 2))],
                    col_face)
    thickness = random_state.randint(min_dim * 0.003, min_dim * 0.015)
    for i in [0, 1, 2]:
        for j in [0, 1, 2, 3]:
            col_edge = (col_face + 128
                        + random_state.randint(-64, 64))\
                % 256  # color that contrasts with the face color
            cv.line(img, (cube[faces[i][j], 0], cube[faces[i][j], 1]),
                    (cube[faces[i][(j + 1) % 4], 0],
                     cube[faces[i][(j + 1) % 4], 1]),
                    col_edge, thickness)

    return {
        "points": junc_points,
        "line_map": line_map
    }
+
+
def gaussian_noise(img):
    """ Overwrite the image with random values and return empty labels.

    The fill is delegated to cv.randu(img, 0, 255), which writes into
    `img` itself. A pure-noise image carries no junctions or lines, so
    both label entries are None.
    """
    cv.randu(img, 0, 255)
    empty_labels = {"points": None, "line_map": None}
    return empty_labels
diff --git a/scalelsd/ssl/datasets/transforms/__init__.py b/scalelsd/ssl/datasets/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scalelsd/ssl/datasets/transforms/homographic_transforms.py b/scalelsd/ssl/datasets/transforms/homographic_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..885063a23a0c0817464b1ac554340f9e8a2409ee
--- /dev/null
+++ b/scalelsd/ssl/datasets/transforms/homographic_transforms.py
@@ -0,0 +1,359 @@
+"""
+This file implements the homographic transforms for data augmentation.
+Code adapted from https://github.com/rpautrat/SuperPoint
+"""
+import numpy as np
+from math import pi
+
+from ..synthetic_util import get_line_map, get_line_heatmap
+import cv2
+import copy
+import shapely.geometry
+
+
def sample_homography(
        shape, perspective=True, scaling=True, rotation=True,
        translation=True, n_scales=5, n_angles=25, scaling_amplitude=0.1,
        perspective_amplitude_x=0.1, perspective_amplitude_y=0.1,
        patch_ratio=0.5, max_angle=pi/2, allow_artifacts=False,
        translation_overflow=0.):
    """
    Computes the homography transformation between a random patch in the
    original image and a warped projection with the same image size.
    As in `tf.contrib.image.transform`, it maps the output point
    (warped patch) to a transformed input point (original patch).
    The original patch, initialized with a simple half-size centered crop,
    is iteratively projected, scaled, translated and rotated.

    Arguments:
        shape: Sequence or array `[height, width]` of the original image.
        perspective: A boolean that enables the perspective and affine transformations.
        scaling: A boolean that enables the random scaling of the patch.
        rotation: A boolean that enables the random rotation of the patch.
        translation: A boolean that enables the random translation of the patch.
        n_scales: The number of tentative scales that are sampled when scaling.
        n_angles: The number of tentative angles that are sampled when rotating.
        scaling_amplitude: Controls the amount of scale. The special value -1
            keeps the patch at scale 1.0.
        perspective_amplitude_x: Controls the perspective effect in x direction.
        perspective_amplitude_y: Controls the perspective effect in y direction.
        patch_ratio: Controls the size of the patches used to create the homography.
        max_angle: Maximum angle used in rotations.
        allow_artifacts: A boolean that enables artifacts when applying the homography.
        translation_overflow: Amount of border artifacts caused by translation.

    Returns:
        homo_mat: A numpy array of shape `[3, 3]` corresponding to the
            homography transform.
        selected_scale: The selected scaling factor (1.0 when `scaling`
            is disabled).
    """
    # Convert shape to ndarray
    if not isinstance(shape, np.ndarray):
        shape = np.array(shape)

    # Default scale reported when no random scaling is performed. Without
    # this, calling with scaling=False hit an unbound local at the return.
    selected_scale = 1.0

    # Corners of the output image
    pts1 = np.array([[0., 0.], [0., 1.], [1., 1.], [1., 0.]])
    # Corners of the input patch
    margin = (1 - patch_ratio) / 2
    pts2 = margin + np.array([[0, 0], [0, patch_ratio],
                              [patch_ratio, patch_ratio], [patch_ratio, 0]])

    # Random perspective and affine perturbations
    if perspective:
        if not allow_artifacts:
            perspective_amplitude_x = min(perspective_amplitude_x, margin)
            perspective_amplitude_y = min(perspective_amplitude_y, margin)

        # normal distribution with mean=0, std=perspective_amplitude_y/2
        perspective_displacement = np.random.normal(
            0., perspective_amplitude_y/2, [1])
        h_displacement_left = np.random.normal(
            0., perspective_amplitude_x/2, [1])
        h_displacement_right = np.random.normal(
            0., perspective_amplitude_x/2, [1])
        pts2 += np.stack([np.concatenate([h_displacement_left,
                                          perspective_displacement], 0),
                          np.concatenate([h_displacement_left,
                                          -perspective_displacement], 0),
                          np.concatenate([h_displacement_right,
                                          perspective_displacement], 0),
                          np.concatenate([h_displacement_right,
                                          -perspective_displacement], 0)])

    # Random scaling: sample several scales, check collision with borders,
    # randomly pick a valid one
    if scaling:
        if scaling_amplitude == -1:
            # Scale sampling is disabled: keep the patch at scale 1.0
            selected_scale = 1.0
            center = np.mean(pts2, axis=0, keepdims=True)
            pts2 = (pts2 - center) * selected_scale + center
        else:
            # Index 0 of `scales` is always the unscaled patch (scale 1.0)
            scales = np.concatenate(
                [[1.], np.random.normal(1, scaling_amplitude/2, [n_scales])], 0)
            center = np.mean(pts2, axis=0, keepdims=True)
            scaled = (pts2 - center)[None, ...] * scales[..., None, None] + center
            if allow_artifacts:
                # No border check: candidates are indices 0..n_scales-1
                valid = np.array(range(n_scales))
            # Check the valid scales
            else:
                valid = np.where(np.all((scaled >= 0.)
                                        & (scaled < 1.), (1, 2)))[0]
            # No valid scale found => recursively call
            if valid.shape[0] == 0:
                return sample_homography(
                    shape, perspective, scaling, rotation, translation,
                    n_scales, n_angles, scaling_amplitude,
                    perspective_amplitude_x, perspective_amplitude_y,
                    patch_ratio, max_angle, allow_artifacts, translation_overflow)

            idx = valid[np.random.uniform(0., valid.shape[0], ()).astype(np.int32)]
            pts2 = scaled[idx]

            # Additionally save and return the selected scale.
            selected_scale = scales[idx]

    # Random translation
    if translation:
        t_min, t_max = np.min(pts2, axis=0), np.min(1 - pts2, axis=0)
        if allow_artifacts:
            t_min += translation_overflow
            t_max += translation_overflow
        pts2 += (np.stack([np.random.uniform(-t_min[0], t_max[0], ()),
                           np.random.uniform(-t_min[1],
                                             t_max[1], ())]))[None, ...]

    # Random rotation: sample several rotations, check collision with borders,
    # randomly pick a valid one
    if rotation:
        angles = np.linspace(-max_angle, max_angle, n_angles)
        # in case no rotation is valid
        angles = np.concatenate([[0.], angles], axis=0)
        center = np.mean(pts2, axis=0, keepdims=True)
        rot_mat = np.reshape(np.stack(
            [np.cos(angles), -np.sin(angles),
             np.sin(angles), np.cos(angles)], axis=1), [-1, 2, 2])
        rotated = np.matmul(
            np.tile((pts2 - center)[None, ...], [n_angles+1, 1, 1]),
            rot_mat) + center
        if allow_artifacts:
            # No border check: candidates are indices 0..n_angles-1
            valid = np.array(range(n_angles))
        else:
            valid = np.where(np.all((rotated >= 0.)
                                    & (rotated < 1.), axis=(1, 2)))[0]

        if valid.shape[0] == 0:
            return sample_homography(
                shape, perspective, scaling, rotation, translation,
                n_scales, n_angles, scaling_amplitude,
                perspective_amplitude_x, perspective_amplitude_y,
                patch_ratio, max_angle, allow_artifacts, translation_overflow)

        idx = valid[np.random.uniform(0., valid.shape[0],
                                      ()).astype(np.int32)]
        pts2 = rotated[idx]

    # Rescale to actual size
    shape = shape[::-1].astype(np.float32)  # different convention [y, x]
    pts1 *= shape[None, ...]
    pts2 *= shape[None, ...]

    # Rows of the linear system whose solution is the 8 free homography
    # coefficients mapping each pts1 corner onto the matching pts2 corner.
    def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]

    def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]

    a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4)
                      for f in (ax, ay)], axis=0)
    p_mat = np.transpose(np.stack([[pts2[i][j] for i in range(4)
                                    for j in range(2)]], axis=0))
    homo_vec, _, _, _ = np.linalg.lstsq(a_mat, p_mat, rcond=None)

    # Compose the homography vector back to matrix
    homo_mat = np.concatenate([
        homo_vec[0:3, 0][None, ...], homo_vec[3:6, 0][None, ...],
        np.concatenate((homo_vec[6], homo_vec[7], [1]),
                       axis=0)[None, ...]], axis=0)

    return homo_mat, selected_scale
+
+
def convert_to_line_segments(junctions, line_map):
    """ Convert junctions and line map to line segments.

    Args:
        junctions: array indexed as junctions[i, :] with two coordinates
            per row.
        line_map: square adjacency array; an entry equal to 1 at [i, j]
            connects junction i and junction j. It is not modified.
    Returns:
        A float array with one row [p1[0], p1[1], p2[0], p2[1]] per
        connection; each connected pair is reported once. The shape is
        [0, 4] when there is no connection.
    """
    # Work on a copy so that consumed connections can be cleared
    line_map_tmp = copy.copy(line_map)

    # Collect rows and concatenate once; the leading empty float block
    # keeps the output dtype and the [0, 4] shape of the empty case.
    rows = [np.zeros([0, 4])]
    for idx in range(junctions.shape[0]):
        # If no connectivity, just skip it
        if line_map_tmp[idx, :].sum() == 0:
            continue
        for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
            p1 = junctions[idx, :]
            p2 = junctions[idx2, :]
            rows.append(np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...])
            # Clear both directions so the pair is not emitted again
            line_map_tmp[idx, idx2] = 0
            line_map_tmp[idx2, idx] = 0

    return np.concatenate(rows, axis=0)
+
+
def compute_valid_mask(image_size, homography,
                       border_margin, valid_mask=None):
    """ Warp a validity mask with a homography and adjust its border.

    Args:
        image_size: (height, width) of the mask.
        homography: matrix handed to cv2.warpPerspective.
        border_margin: a positive value erodes the warped mask, a negative
            value dilates it; the elliptical kernel is 2*|margin| wide.
            0 leaves the warped mask untouched.
        valid_mask: optional mask to warp; an all-ones mask of
            `image_size` is used when None.
    Returns:
        The warped (and possibly eroded/dilated) mask.
    """
    # Warp the mask
    if valid_mask is None:
        initial_mask = np.ones(image_size)
    else:
        initial_mask = valid_mask
    mask = cv2.warpPerspective(
        initial_mask, homography, (image_size[1], image_size[0]),
        flags=cv2.INTER_NEAREST)

    # Optionally perform erosion. The kernel size is cast to int, as the
    # dilation branch does, so a float margin behaves the same either way.
    if border_margin > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                           (int(border_margin)*2, )*2)
        mask = cv2.erode(mask, kernel)

    # Perform dilation if border_margin is negative
    elif border_margin < 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                           (abs(int(border_margin))*2, )*2)
        mask = cv2.dilate(mask, kernel)

    return mask
+
+
def warp_line_segment(line_segments, homography, image_size):
    """ Warp the line segments using a homography.

    Args:
        line_segments: array with one segment per row, two end points of
            two coordinates each. The end points are flipped before the
            homography is applied and flipped back afterwards (the code
            treats the stored order as "hw" and the homography as "xy").
        homography: 3x3 matrix applied to the homogeneous end points.
        image_size: (height, width) used to build the clipping rectangle.
    Returns:
        Float array [M, 4] of warped segments in the input coordinate
        order. Segments entirely inside the image are kept, segments
        crossing the border are clipped, and the rest are dropped.
    """
    # Separate the line segments into 2N points to apply matrix operation
    num_segments = line_segments.shape[0]

    junctions = np.concatenate(
        (line_segments[:, :2],  # The first junction of each segment.
         line_segments[:, 2:]),  # The second junction of each segment.
        axis=0)
    # Convert to homogeneous coordinates
    # Flip the junctions before converting to homogeneous (xy format)
    junctions = np.flip(junctions, axis=1)
    junctions = np.concatenate((junctions, np.ones([2*num_segments, 1])),
                               axis=1)
    warped_junctions = np.matmul(homography, junctions.T).T

    # Convert back to segments
    warped_junctions = warped_junctions[:, :2] / warped_junctions[:, 2:]
    # (Convert back to hw format)
    warped_junctions = np.flip(warped_junctions, axis=1)
    warped_segments = np.concatenate(
        (warped_junctions[:num_segments, :],
         warped_junctions[num_segments:, :]),
        axis=1
    )

    # Check the intersections with the boundary. Rows are collected in a
    # list and concatenated once; the leading empty block keeps the
    # float dtype and the [0, 4] shape when nothing survives.
    kept = [np.zeros([0, 4])]
    image_poly = shapely.geometry.Polygon(
        [[0, 0], [image_size[1]-1, 0], [image_size[1]-1, image_size[0]-1],
         [0, image_size[0]-1]])
    for idx in range(warped_segments.shape[0]):
        # Get the line segment
        seg_raw = warped_segments[idx, :]  # in HW format.
        # Convert to shapely line (flip to xy format)
        seg = shapely.geometry.LineString([np.flip(seg_raw[:2]),
                                           np.flip(seg_raw[2:])])

        # The line segment is just inside the image.
        if seg.intersection(image_poly) == seg:
            kept.append(seg_raw[None, ...])

        # Intersect with the image.
        elif seg.intersects(image_poly):
            # Check intersection
            try:
                p = np.array(
                    seg.intersection(image_poly).coords).reshape([-1, 4])
            # If the intersection is not a two-point segment (e.g. a
            # single touching point), skip it. Exception instead of a
            # bare except so KeyboardInterrupt/SystemExit still propagate.
            except Exception:
                continue
            # Flip the clipped end points back to the hw order
            segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:],
                                                                 axis=0)])[None, ...]
            kept.append(segment)

        else:
            continue

    return np.concatenate(kept, axis=0)
+
+
class homography_transform(object):
    """ Callable applying a homography to an image and its line labels. """
    def __init__(self, image_size, homograpy_config,
                 border_margin=0, min_label_len=20):
        """ Store the configuration.

        Args:
            image_size: (height, width) of the processed images.
            homograpy_config: keyword arguments forwarded to
                sample_homography.
            border_margin: margin forwarded to compute_valid_mask.
            min_label_len: minimum label length in pixels; a float
                below 1 is rejected.
        Raises:
            ValueError: if min_label_len is a float smaller than 1.
        """
        height, width = image_size[0], image_size[1]
        self.homo_config = homograpy_config
        self.image_size = image_size
        # cv2.warpPerspective expects the output size as (width, height)
        self.target_size = (width, height)
        self.border_margin = border_margin
        if (min_label_len < 1) and isinstance(min_label_len, float):
            raise ValueError("[Error] min_label_len should be in pixels.")
        self.min_label_len = min_label_len

    def __call__(self, input_image, junctions, line_map,
                 valid_mask=None, homo=None, scale=None):
        """ Warp the image, the valid mask, the junctions and the line map.

        A new homography is sampled unless both `homo` and `scale` are
        given. Returns a dict with keys "junctions", "warped_image",
        "valid_mask", "line_map", "homo" and "scale".
        """
        # Sample a fresh homography when either piece is missing
        if homo is None or scale is None:
            homo, scale = sample_homography(self.image_size,
                                            **self.homo_config)

        # Image first, then the mask, both with the same homography
        warped_image = cv2.warpPerspective(
            input_image, homo, self.target_size, flags=cv2.INTER_LINEAR)
        valid_mask = compute_valid_mask(self.image_size, homo,
                                        self.border_margin, valid_mask)

        # Go through explicit segments so they can be warped and clipped
        segments = convert_to_line_segments(junctions, line_map)
        warped_segments = warp_line_segment(segments, homo, self.image_size)

        # Rebuild junctions and the line map from the surviving segments
        endpoints = np.concatenate((warped_segments[:, :2],
                                    warped_segments[:, 2:]), axis=0)
        if endpoints.shape[0] != 0:
            junctions_new = np.unique(endpoints, axis=0)
            line_map = get_line_map(junctions_new,
                                    warped_segments).astype(np.int32)
        else:
            junctions_new = np.zeros([0, 2])
            line_map = np.zeros([0, 0])

        outputs = {
            "junctions": junctions_new,
            "warped_image": warped_image,
            "valid_mask": valid_mask,
            "line_map": line_map,
            "homo": homo,
            "scale": scale
        }
        return outputs
diff --git a/scalelsd/ssl/datasets/transforms/photometric_transforms.py b/scalelsd/ssl/datasets/transforms/photometric_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fa44bf0efa93a47e5f8012988058f1cbd49324f
--- /dev/null
+++ b/scalelsd/ssl/datasets/transforms/photometric_transforms.py
@@ -0,0 +1,185 @@
+"""
+Common photometric transforms for data augmentation.
+"""
+import numpy as np
+from PIL import Image
+from torchvision import transforms as transforms
+import cv2
+
+
# Names of the available photometric augmentations. Each entry matches the
# name of one of the augmentation classes defined below in this module.
available_augmentations = [
    'additive_gaussian_noise',
    'additive_speckle_noise',
    'random_brightness',
    'random_contrast',
    'additive_shade',
    'motion_blur'
]
+
+
class additive_gaussian_noise(object):
    """ Add zero-mean gaussian noise with a randomly drawn stddev. """
    def __init__(self, stddev_range=None):
        # Fall back to the default [min, max] stddev when none is given
        self.stddev_range = [5, 95] if stddev_range is None else stddev_range

    def __call__(self, input_image):
        """ Return the noisy image, clipped to [0, 255]. """
        lowest, highest = self.stddev_range[0], self.stddev_range[1]
        # One stddev per call, shared by every pixel
        sigma = np.random.uniform(lowest, highest)
        perturbation = np.random.normal(0., sigma, size=input_image.shape)
        return (input_image + perturbation).clip(0., 255.)
+
+
class additive_speckle_noise(object):
    """ Randomly force a fraction of the pixels to 0 or 255. """
    def __init__(self, prob_range=None):
        # Fall back to the default [min, max] probability when none is given
        self.prob_range = [0.0, 0.005] if prob_range is None else prob_range

    def __call__(self, input_image):
        """ Return a noisy copy of the image (values assumed in 0~255). """
        # One probability per call, then one uniform sample per pixel
        prob = np.random.uniform(self.prob_range[0], self.prob_range[1])
        draws = np.random.uniform(0., 1., size=input_image.shape)

        result = input_image.copy()
        # Low draws become black first, high draws become white afterwards
        result[draws <= prob] = 0.
        result[draws >= (1 - prob)] = 255.
        return result
+
+
class random_brightness(object):
    """ Random brightness change backed by transforms.ColorJitter. """
    def __init__(self, brightness=None):
        # Fall back to the default jitter strength when none is given
        self.brightness = 0.5 if brightness is None else brightness

        # The jitter transform is built once and reused for every call
        self.transform = transforms.ColorJitter(brightness=self.brightness)

    def __call__(self, input_image):
        """ Apply the jitter and return the result as a numpy array. """
        pil_image = input_image
        # numpy inputs are cast to uint8 and wrapped as a PIL image first
        if isinstance(pil_image, np.ndarray):
            pil_image = Image.fromarray(pil_image.astype(np.uint8))

        jittered = self.transform(pil_image)
        return np.array(jittered)
+
+
class random_contrast(object):
    """ Random contrast change backed by transforms.ColorJitter. """
    def __init__(self, contrast=None):
        # Fall back to the default jitter strength when none is given
        self.contrast = 0.5 if contrast is None else contrast

        # The jitter transform is built once and reused for every call
        self.transform = transforms.ColorJitter(contrast=self.contrast)

    def __call__(self, input_image):
        """ Apply the jitter and return the result as a numpy array. """
        pil_image = input_image
        # numpy inputs are cast to uint8 and wrapped as a PIL image first
        if isinstance(pil_image, np.ndarray):
            pil_image = Image.fromarray(pil_image.astype(np.uint8))

        jittered = self.transform(pil_image)
        return np.array(jittered)
+
+
class additive_shade(object):
    """ Shade the image with a blurred mask of random ellipses. """
    def __init__(self, nb_ellipses=20, transparency_range=None,
                 kernel_size_range=None):
        self.nb_ellipses = nb_ellipses
        # Defaults are applied when a range is not supplied
        self.transparency_range = ([-0.5, 0.8] if transparency_range is None
                                   else transparency_range)
        self.kernel_size_range = ([250, 350] if kernel_size_range is None
                                  else kernel_size_range)

    def __call__(self, input_image):
        """ Return the shaded image, clipped to [0, 255], in the input shape.

        NOTE(review): the original carries an open ToDo about converting
        the input to a numpy array first; `.shape` and indexing are used
        directly on `input_image`.
        """
        img_h, img_w = input_image.shape[0], input_image.shape[1]
        min_dim = min(input_image.shape[:2]) / 4
        mask = np.zeros(input_image.shape[:2], np.uint8)
        for _ in range(self.nb_ellipses):
            # Semi-axes, then centre, then angle: the draw order is kept
            ax = int(max(np.random.rand() * min_dim, min_dim / 5))
            ay = int(max(np.random.rand() * min_dim, min_dim / 5))
            max_rad = max(ax, ay)
            # Centres are drawn so the largest radius stays inside
            x = np.random.randint(max_rad, img_w - max_rad)
            y = np.random.randint(max_rad, img_h - max_rad)
            angle = np.random.rand() * 90
            cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1)

        transparency = np.random.uniform(*self.transparency_range)
        kernel_size = np.random.randint(*self.kernel_size_range)

        # kernel_size has to be odd: bump even values by one
        kernel_size += 1 - (kernel_size % 2)
        mask = cv2.GaussianBlur(mask.astype(np.float32),
                                (kernel_size, kernel_size), 0)
        attenuation = 1 - transparency * mask[..., np.newaxis]/255.
        shaded = np.clip(input_image[..., None] * attenuation, 0, 255)

        return np.reshape(shaded, input_image.shape)
+
+
class motion_blur(object):
    """ Directional blur with a randomly sized line kernel. """
    def __init__(self, max_kernel_size=10):
        self.max_kernel_size = max_kernel_size

    def __call__(self, input_image):
        """ Blur along a random direction; the input shape is preserved. """
        # Direction first, kernel size second: the draw order is kept
        mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up'])
        half_steps = int(round((self.max_kernel_size + 1) / 2))
        ksize = np.random.randint(0, half_steps) * 2 + 1  # always odd
        center = int((ksize - 1) / 2)

        # Binary line through the kernel centre in the chosen direction
        kernel = np.zeros((ksize, ksize))
        if mode == 'diag_up':
            kernel = np.flip(np.eye(ksize), 0)
        elif mode == 'diag_down':
            kernel = np.eye(ksize)
        elif mode == 'v':
            kernel[:, center] = 1.
        elif mode == 'h':
            kernel[center, :] = 1.

        # Weight the line with a centred gaussian and normalise to sum 1
        var = ksize * ksize / 16.
        rows, cols = np.indices((ksize, ksize))
        sq_dist = np.square(rows - center) + np.square(cols - center)
        kernel *= np.exp(-sq_dist / (2. * var))
        kernel /= np.sum(kernel)

        blurred = cv2.filter2D(input_image, -1, kernel)
        return np.reshape(blurred, input_image.shape)
+
+
class normalize_image(object):
    """ Divide the image by 255 and cast it to float32. """
    def __init__(self):
        # Divisor applied to every pixel value
        self.normalize_value = 255

    def __call__(self, input_image):
        scaled = input_image / self.normalize_value
        return scaled.astype(np.float32)
diff --git a/scalelsd/ssl/datasets/transforms/utils.py b/scalelsd/ssl/datasets/transforms/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c2e6fc1eacb631dff0425e7c0fe966a239b518
--- /dev/null
+++ b/scalelsd/ssl/datasets/transforms/utils.py
@@ -0,0 +1,121 @@
+"""
+Some useful functions for dataset pre-processing
+"""
+import cv2
+import numpy as np
+import shapely.geometry as sg
+
+from ..synthetic_util import get_line_map
+from . import homographic_transforms as homoaug
+
+
def random_scaling(image, junctions, line_map, scale=1., h_crop=0, w_crop=0):
    """ Rescale an image and its line labels while keeping the image size.

    Args:
        image: array whose first two dimensions are (H, W).
        junctions, line_map: labels forwarded to
            process_junctions_and_line_map.
        scale: scaling factor; >= 1 zooms in (resize then crop), < 1
            zooms out (resize then paste into a zero image).
        h_crop, w_crop: top-left corner of the crop when zooming in.
    Returns:
        (image, junctions, line_map, valid_mask), where valid_mask is an
        int32 [H, W] array marking the pixels covered by image content.
    """
    H, W = image.shape[:2]
    H_scale = round(H * scale)
    W_scale = round(W * scale)

    # Nothing to do if the scale is too close to 1
    if (H_scale, W_scale) == (H, W):
        return (image, junctions, line_map, np.ones([H, W], dtype=np.int32))

    if scale >= 1.:
        # Zoom-in => resize and crop a window of the original size
        enlarged = cv2.resize(image, (W_scale, H_scale),
                              interpolation=cv2.INTER_LINEAR)
        image = enlarged[h_crop:h_crop+H, w_crop:w_crop+W, ...]
        valid_mask = np.ones([H, W], dtype=np.int32)

        junctions, line_map = process_junctions_and_line_map(
            h_crop, w_crop, H, W, H_scale, W_scale,
            junctions, line_map, "zoom-in")
        return image, junctions, line_map, valid_mask

    # Zoom-out => resize and paste in the middle of a zero canvas
    canvas_shape = image.shape
    shrunk = cv2.resize(image, (W_scale, H_scale),
                        interpolation=cv2.INTER_AREA)
    h_start = round((H - H_scale) / 2)
    w_start = round((W - W_scale) / 2)
    rows = slice(h_start, h_start+H_scale)
    cols = slice(w_start, w_start+W_scale)
    image = np.zeros(canvas_shape, dtype='float')
    image[rows, cols, ...] = shrunk
    # Only the pasted region is valid
    valid_mask = np.zeros([H, W], dtype=np.int32)
    valid_mask[rows, cols] = 1

    junctions, line_map = process_junctions_and_line_map(
        h_start, w_start, H, W, H_scale, W_scale,
        junctions, line_map, "zoom-out")

    return image, junctions, line_map, valid_mask
+
+
def process_junctions_and_line_map(h_start, w_start, H, W, H_scale, W_scale,
                                   junctions, line_map, mode="zoom-in"):
    """ Update junctions and line map after a rescaling of the image.

    NOTE: the rescaling is written into `junctions` in place, so the
    caller's array is modified in both modes.

    Args:
        h_start, w_start: crop offset ("zoom-in") or paste offset
            ("zoom-out").
        H, W: original image size.
        H_scale, W_scale: size of the rescaled image.
        junctions: array with two coordinates per row; column 0 is scaled
            with the height and column 1 with the width.
        line_map: connectivity between the junctions.
        mode: "zoom-in" clips the segments to the crop window and rebuilds
            junctions and line map; "zoom-out" only rescales and shifts
            the junctions.
    Returns:
        (junctions, line_map) expressed in the output image.
    Raises:
        ValueError: if `mode` is neither "zoom-in" nor "zoom-out".
    """
    if mode == "zoom-in":
        junctions[:, 0] = junctions[:, 0] * H_scale / H
        junctions[:, 1] = junctions[:, 1] * W_scale / W
        line_segments = homoaug.convert_to_line_segments(junctions, line_map)
        # Crop segments to the new boundaries. Rows are collected in a
        # list and concatenated once; the leading empty block keeps the
        # [0, 4] shape when nothing survives.
        kept = [np.zeros([0, 4])]
        image_poly = sg.Polygon(
            [[w_start, h_start],
             [w_start+W, h_start],
             [w_start+W, h_start+H],
             [w_start, h_start+H]
             ])
        for idx in range(line_segments.shape[0]):
            # Get the line segment
            seg_raw = line_segments[idx, :]  # in HW format.
            # Convert to shapely line (flip to xy format)
            seg = sg.LineString([np.flip(seg_raw[:2]),
                                 np.flip(seg_raw[2:])])
            # The line segment is just inside the image.
            if seg.intersection(image_poly) == seg:
                kept.append(seg_raw[None, ...])
            # Intersect with the image.
            elif seg.intersects(image_poly):
                # Check intersection
                try:
                    p = np.array(
                        seg.intersection(image_poly).coords).reshape([-1, 4])
                # If the intersection is not a two-point segment (e.g. a
                # single touching point), skip it. Exception instead of a
                # bare except so KeyboardInterrupt/SystemExit propagate.
                except Exception:
                    continue
                # Flip the clipped end points back to the hw order
                segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:],
                                                                     axis=0)])[None, ...]
                kept.append(segment)
            else:
                continue
        line_segments_new = np.concatenate(kept, axis=0)
        line_segments_new = (np.round(line_segments_new)).astype(np.int32)
        # Filter segments with 0 length
        segment_lens = np.linalg.norm(
            line_segments_new[:, :2] - line_segments_new[:, 2:], axis=-1)
        seg_mask = segment_lens != 0
        line_segments_new = line_segments_new[seg_mask, :]
        # Convert back to junctions and line_map
        junctions_new = np.concatenate(
            (line_segments_new[:, :2], line_segments_new[:, 2:]), axis=0)
        if junctions_new.shape[0] == 0:
            junctions_new = np.zeros([0, 2])
            line_map = np.zeros([0, 0])
        else:
            junctions_new = np.unique(junctions_new, axis=0)
            # Generate line map from points and segments
            line_map = get_line_map(junctions_new,
                                    line_segments_new).astype(np.int32)
            # Express the junctions relative to the crop origin
            junctions_new[:, 0] -= h_start
            junctions_new[:, 1] -= w_start
        junctions = junctions_new
    elif mode == "zoom-out":
        # Process the junctions
        junctions[:, 0] = (junctions[:, 0] * H_scale / H) + h_start
        junctions[:, 1] = (junctions[:, 1] * W_scale / W) + w_start
    else:
        raise ValueError("[Error] unknown mode...")

    return junctions, line_map
diff --git a/scalelsd/ssl/datasets/wireframe.py b/scalelsd/ssl/datasets/wireframe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5218718e56f7a3244545733023adf1bc1fa9de6d
--- /dev/null
+++ b/scalelsd/ssl/datasets/wireframe.py
@@ -0,0 +1,19 @@
+import torch
+from torch.utils.data import Dataset
+
+import os.path as osp
+import json
+import cv2
+from skimage import io
+from PIL import Image
+import numpy as np
+import random
+from torch.utils.data.dataloader import default_collate
+from torch.utils.data.dataloader import DataLoader
+import matplotlib.pyplot as plt
+from torchvision.transforms import functional as F
+import copy
+
+
+# class WireframeDataset(Dataset):
+# def __init__(self, )
\ No newline at end of file
diff --git a/scalelsd/ssl/datasets/wireframe_dataset.py b/scalelsd/ssl/datasets/wireframe_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c437a1bf97b5f486ad5e6d56724356e25bb1f0ea
--- /dev/null
+++ b/scalelsd/ssl/datasets/wireframe_dataset.py
@@ -0,0 +1,1231 @@
+"""
+This file implements the wireframe dataset object for pytorch.
+Some parts of the code are adapted from https://github.com/zhou13/lcnn
+"""
+import os
+import math
+import copy
+from skimage.io import imread
+from skimage import color
+import PIL
+import numpy as np
+import h5py
+import cv2
+import pickle
+import torch
+import torch.utils.data.dataloader as torch_loader
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from ..config.project_config import Config as cfg
+from .transforms import photometric_transforms as photoaug
+from .transforms import homographic_transforms as homoaug
+from .transforms.utils import random_scaling
+from .synthetic_util import get_line_heatmap
+from ..misc.train_utils import parse_h5_data
+from ..misc.geometry_utils import warp_points, mask_points
+from tqdm import tqdm
+
+def wireframe_collate_fn(batch):
+ """ Customized collate_fn for wireframe dataset. """
+ batch_keys = ["image", "junction_map", "valid_mask", "heatmap",
+ "heatmap_pos", "heatmap_neg", "homography",
+ "line_points", "line_indices"]
+ list_keys = ["junctions", "line_map", "line_map_pos",
+ "line_map_neg", "file_key"]
+
+ outputs = {}
+ for data_key in batch[0].keys():
+ batch_match = sum([_ in data_key for _ in batch_keys])
+ list_match = sum([_ in data_key for _ in list_keys])
+ # print(batch_match, list_match)
+ if batch_match > 0 and list_match == 0:
+ outputs[data_key] = torch_loader.default_collate(
+ [b[data_key] for b in batch])
+ elif batch_match == 0 and list_match > 0:
+ outputs[data_key] = [b[data_key] for b in batch]
+ elif batch_match == 0 and list_match == 0:
+ continue
+ else:
+ raise ValueError(
+ "[Error] A key matches batch keys and list keys simultaneously.")
+
+ return outputs
+
+
+class WireframeDataset(Dataset):
+ def __init__(self, mode="train", config=None):
+ super(WireframeDataset, self).__init__()
+ if not mode in ["train", "test"]:
+ raise ValueError(
+ "[Error] Unknown mode for Wireframe dataset. Only 'train' and 'test'.")
+ self.mode = mode
+
+ if config is None:
+ self.config = self.get_default_config()
+ else:
+ self.config = config
+ # Also get the default config
+ self.default_config = self.get_default_config()
+
+ # Get cache setting
+ self.dataset_name = self.get_dataset_name()
+ self.cache_name = self.get_cache_name()
+ self.cache_path = cfg.wireframe_cache_path
+
+ # Get the ground truth source
+ self.gt_source = self.config.get("gt_source_%s"%(self.mode),
+ "official")
+
+ if not self.gt_source == "official":
+ # Convert gt_source to full path
+ self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source)
+ # Check the full path exists
+ if not os.path.exists(self.gt_source):
+ raise ValueError(
+ "[Error] The specified ground truth source does not exist.")
+
+
+ # Get the filename dataset
+ print("[Info] Initializing wireframe dataset...")
+ self.filename_dataset, self.datapoints = self.construct_dataset()
+
+ # Get dataset length
+ self.dataset_length = len(self.datapoints)
+ # self.dataset_length = len(self.datapoints)//4
+
+ # Get repeatability evaluation set
+ if self.mode == "test" and self.config.get("evaluation", None) is not None:
+ # Get the cache name
+ tmp = self.cache_name.split(self.mode)
+ self.rep_i_cache_name = tmp[0] + self.mode + "_rep_i" + tmp[1]
+ self.rep_v_cache_name = tmp[0] + self.mode + "_rep_v" + tmp[1]
+
+ # Get the repeatability config
+ self.rep_config = self.config["evaluation"]["repeatability"]
+ self.rep_eval_dataset = self.construct_rep_eval_dataset()
+ self.rep_eval_datapoints = self.get_rep_eval_datapoints()
+ # Print some info
+ print("[Info] Successfully initialized dataset")
+ print("\t Name: wireframe")
+ print("\t Mode: %s" %(self.mode))
+ print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode),
+ "official")))
+ print("\t Counts: %d" %(self.dataset_length))
+ print("----------------------------------------")
+
+ #######################################
+ ## Dataset construction related APIs ##
+ #######################################
+
+ def construct_dataset(self):
+ """ Construct the dataset (from scratch or from cache). """
+ # Check if the filename cache exists
+ # If cache exists, load from cache
+ if self._check_dataset_cache():
+ print("\t Found filename cache %s at %s"%(self.cache_name,
+ self.cache_path))
+ print("\t Load filename cache...")
+ filename_dataset, datapoints = self.get_filename_dataset_from_cache()
+ # If not, initialize dataset from scratch
+ else:
+ print("\t Can't find filename cache ...")
+ print("\t Create filename dataset from scratch...")
+ filename_dataset, datapoints = self.get_filename_dataset()
+ print("\t Create filename dataset cache...")
+ self.create_filename_dataset_cache(filename_dataset, datapoints)
+
+ return filename_dataset, datapoints
+
+ def create_filename_dataset_cache(self, filename_dataset, datapoints):
+ """ Create filename dataset cache for faster initialization. """
+ # Check cache path exists
+ if not os.path.exists(self.cache_path):
+ os.makedirs(self.cache_path)
+
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ data = {
+ "filename_dataset": filename_dataset,
+ "datapoints": datapoints
+ }
+ with open(cache_file_path, "wb") as f:
+ pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
+
+ def get_filename_dataset_from_cache(self):
+ """ Get filename dataset from cache. """
+ # Load from pkl cache
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ with open(cache_file_path, "rb") as f:
+ data = pickle.load(f)
+
+ return data["filename_dataset"], data["datapoints"]
+
+ def get_filename_dataset(self):
+ # Get the path to the dataset
+ if self.mode == "train":
+ dataset_path = os.path.join(cfg.wireframe_dataroot, "train")
+ elif self.mode == "test":
+ dataset_path = os.path.join(cfg.wireframe_dataroot, "valid")
+
+ # Get paths to all image files
+ image_paths = sorted([os.path.join(dataset_path, _)
+ for _ in os.listdir(dataset_path)\
+ if os.path.splitext(_)[-1] == ".png"])
+ # Get the shared prefix
+ prefix_paths = [_.split(".png")[0] for _ in image_paths]
+
+ # Get the label paths (different procedure for different split)
+ if self.mode == "train":
+ label_paths = [_ + "_label.npz" for _ in prefix_paths]
+ else:
+ label_paths = [_ + "_label.npz" for _ in prefix_paths]
+ mat_paths = [p[:-2] + "_line.mat" for p in prefix_paths]
+
+ # Verify all the images and labels exist
+ for idx in range(len(image_paths)):
+ image_path = image_paths[idx]
+ label_path = label_paths[idx]
+ if (not (os.path.exists(image_path)
+ and os.path.exists(label_path))):
+ raise ValueError(
+ "[Error] The image and label do not exist. %s"%(image_path))
+ # Further verify mat paths for test split
+ if self.mode == "test":
+ mat_path = mat_paths[idx]
+ if not os.path.exists(mat_path):
+ raise ValueError(
+ "[Error] The mat file does not exist. %s"%(mat_path))
+
+ # Construct the filename dataset
+ num_pad = int(math.ceil(math.log10(len(image_paths))) + 1)
+ filename_dataset = {}
+ for idx in range(len(image_paths)):
+ # Get the file key
+ key = self.get_padded_filename(num_pad, idx)
+
+ filename_dataset[key] = {
+ "image": image_paths[idx],
+ "label": label_paths[idx]
+ }
+
+ # Get the datapoints
+ datapoints = list(sorted(filename_dataset.keys()))
+
+ return filename_dataset, datapoints
+
+ def get_dataset_name(self):
+ """ Get dataset name from dataset config / default config. """
+ if self.config["dataset_name"] is None:
+ dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
+ else:
+ dataset_name = self.config["dataset_name"] + "_%s" % self.mode
+
+ return dataset_name
+
+ def get_cache_name(self):
+ """ Get cache name from dataset config / default config. """
+ if self.config["dataset_name"] is None:
+ dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
+ else:
+ dataset_name = self.config["dataset_name"] + "_%s" % self.mode
+ # Compose cache name
+ cache_name = dataset_name + "_cache.pkl"
+
+ return cache_name
+
+ @staticmethod
+ def get_padded_filename(num_pad, idx):
+ """ Get the padded filename using adaptive padding. """
+ file_len = len("%d" % (idx))
+ filename = "0" * (num_pad - file_len) + "%d" % (idx)
+
+ return filename
+
+ def get_default_config(self):
+ """ Get the default configuration. """
+ return {
+ "dataset_name": "wireframe",
+ "add_augmentation_to_all_splits": False,
+ "preprocessing": {
+ "resize": [240, 320],
+ "blur_size": 11
+ },
+ "augmentation":{
+ "photometric":{
+ "enable": False
+ },
+ "homographic":{
+ "enable": False
+ },
+ },
+ }
+
+ ###########################################
+ ## Repeatability evaluation related APIs ##
+ ###########################################
+ # Construct repeatability evaluation dataset (from scratch or from cache)
+ def construct_rep_eval_dataset(self):
+ rep_eval_dataset = {}
+ # Check if viewpoint and illumination cache exists
+ if self.rep_config["photometric"]["enable"]:
+ if self._check_rep_eval_dataset_cache(split="i"):
+ print("\t Found repeatability illumination cache %s at %s"%(self.rep_i_cache_name, self.cache_path))
+ print("\t Load repeatability illumination cache...")
+ rep_i_keymap, rep_i_dataset_name = self.get_rep_eval_dataset_from_cache(split="i")
+ else:
+ print("\t Can't find repeatability illumination cache ...")
+ print("\t Create repeatability illumination dataset from scratch...")
+ rep_i_keymap, rep_i_dataset_name = self.get_rep_eval_dataset(split="i")
+ print("\t Create filename dataset cache...")
+ self.create_rep_eval_dataset_cache("i", rep_i_keymap, rep_i_dataset_name)
+ else:
+ rep_i_keymap = None
+ rep_i_dataset_name = None
+
+ rep_eval_dataset["illumination"] = {
+ "keymap": rep_i_keymap,
+ "dataset_name": rep_i_dataset_name
+ }
+
+ if self.rep_config["homographic"]["enable"]:
+ if self._check_rep_eval_dataset_cache(split="v"):
+ print("\t Found repeatability viewpoint cache %s at %s"%(self.rep_v_cache_name, self.cache_path))
+ print("\t Load repeatability viewpoint cache...")
+ rep_v_keymap, rep_v_dataset_name = self.get_rep_eval_dataset_from_cache(split="v")
+ else:
+ print("\t Can't find repeatability viewpoint cache ...")
+ print("\t Create repeatability viewpoint dataset from scratch...")
+ rep_v_keymap, rep_v_dataset_name = self.get_rep_eval_dataset(split="v")
+ print("\t Create filename dataset cache...")
+ self.create_rep_eval_dataset_cache("v", rep_v_keymap, rep_v_dataset_name)
+ else:
+ rep_v_keymap = None
+ rep_v_dataset_name = None
+
+ rep_eval_dataset["viewpoint"] = {
+ "keymap": rep_v_keymap,
+ "dataset_name": rep_v_dataset_name
+ }
+
+ return rep_eval_dataset
+
+ # Create filename dataset cache for faster initialization
+ def create_rep_eval_dataset_cache(self, split, keymap, dataset_name):
+ # Check cache path exists
+ if not os.path.exists(self.cache_path):
+ os.makedirs(self.cache_path)
+
+ if split == "i":
+ cache_file_path = os.path.join(self.cache_path, self.rep_i_cache_name)
+ elif split == "v":
+ cache_file_path = os.path.join(self.cache_path, self.rep_v_cache_name)
+ else:
+ raise ValueError("[Error] Unknown split for repeatability evaluation.")
+
+ data = {
+ "keymap": keymap,
+ "dataset_name": dataset_name
+ }
+ with open(cache_file_path, "wb") as f:
+ pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
+
+ # Get filename dataset from cache
+ def get_rep_eval_dataset_from_cache(self, split):
+ # Load from pkl cache
+ if split == "i":
+ cache_file_path = os.path.join(self.cache_path, self.rep_i_cache_name)
+ elif split == "v":
+ cache_file_path = os.path.join(self.cache_path, self.rep_v_cache_name)
+ else:
+ raise ValueError("[Error] Unknown split for repeatability evaluation.")
+
+ with open(cache_file_path, "rb") as f:
+ data = pickle.load(f)
+
+ return data["keymap"], data["dataset_name"]
+
+ # Initialize the repeatability evaluation dataset from scratch
+ def get_rep_eval_dataset(self, split):
+ image_shape = self.config["preprocessing"]["resize"]
+
+ # Initialize the illumination set
+ if split == "i":
+ # Set the random seed before continuing
+ seed = self.rep_config["seed"]
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ raise NotImplementedError
+
+ # Initialize the viewpoint set
+ elif split == "v":
+ # Set the random seed before continuing
+ seed = self.rep_config["seed"]
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ v_keymap = {}
+ # Get the name for the output h5 dataset
+ v_dataset_name = self.rep_v_cache_name.split(".pkl")[0] + ".h5"
+ v_dataset_path = os.path.join(self.cache_path, v_dataset_name)
+ with h5py.File(v_dataset_path, "w") as f:
+ # Iterate through all the file_key in test set
+ for idx, key in enumerate(tqdm(list(self.filename_dataset.keys()), ascii=True)):
+ # Sample N random homography
+ file_key_lst = []
+ for i in range(self.rep_config["homographic"]["num_samples"]):
+ file_key = key + "_" + str(i)
+
+ # Sample a random homography
+ homo_mat, _ = homoaug.sample_homography(image_shape,
+ **self.rep_config["homographic"]["params"])
+
+ file_key_lst.append(file_key)
+ f.create_dataset(file_key, data=homo_mat, compression="gzip")
+
+ v_keymap[key] = file_key_lst
+
+ return v_keymap, v_dataset_name
+
+ else:
+ raise ValueError("[Error] Unknow split for repeatability evaluation.")
+
+ # Convert ref image and warped images to list of evaluation pairs
+ def get_rep_eval_datapoints(self):
+ datapoints = {
+ "illumination": [],
+ "viewpoint": []
+ }
+
+ # Iterate through all the ref image
+ if self.rep_eval_dataset["illumination"]["keymap"] is not None:
+ for ref_key in sorted(self.rep_eval_dataset["illumination"]["keymap"].keys()):
+ pair_lst = [[ref_key, _] for _ in self.rep_eval_dataset["illumination"]["keymap"][ref_key]]
+ datapoints["illumination"] += pair_lst
+
+ if self.rep_eval_dataset["viewpoint"]["keymap"] is not None:
+ for ref_key in sorted(self.rep_eval_dataset["viewpoint"]["keymap"].keys()):
+ pair_lst = [[ref_key, _] for _ in self.rep_eval_dataset["viewpoint"]["keymap"][ref_key]]
+ datapoints["viewpoint"] += pair_lst
+
+ return datapoints
+
+
+ ###########################################
+ ## Repeatability evaluation related APIs ##
+ ###########################################
+ # Get the corresponding data according to the "index in rep_eval_datapoints".
+ def get_rep_eval_data(self, split, idx):
+ assert split in ["viewpoint", "illumination"]
+ datapoint = self.rep_eval_datapoints[split][idx]
+
+ # Get reference image
+ ref_key = datapoint[0]
+ # Get the data paths
+ data_path = self.filename_dataset[ref_key]
+ # Read in the image and npz labels (but haven't applied any transform)
+ image = imread(data_path["image"])
+
+ # Resize the image before photometric and homographical augmentations
+ image_size = image.shape[:2]
+ if not(list(image_size) == self.config["preprocessing"]["resize"]):
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) *255.).astype(np.uint8)
+
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Get target image
+ if split == "viewpoint":
+ target_key = datapoint[1]
+ dataset_path = os.path.join(self.cache_path, self.rep_eval_dataset[split]["dataset_name"])
+
+ with h5py.File(dataset_path, "r") as f:
+ homo_mat = np.array(f[target_key])
+
+ # Warp the image
+ target_size = (image.shape[1], image.shape[0])
+ target_image = cv2.warpPerspective(image, homo_mat, target_size,
+ flags=cv2.INTER_LINEAR)
+
+ else:
+ raise NotImplementedError
+
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+
+ return {
+ "ref_image": to_tensor(image),
+ "ref_key": ref_key,
+ "target_image": to_tensor(target_image),
+ "target_key": target_key,
+ "homo_mat": homo_mat
+ }
+
+ ############################################
+ ## Pytorch and preprocessing related APIs ##
+ ############################################
+ # Get data from the information from filename dataset
+ @staticmethod
+ def get_data_from_path(data_path):
+ output = {}
+
+ # Get image data
+ image_path = data_path["image"]
+ image = imread(image_path)
+ output["image"] = image
+
+ # Get the npz label
+ """ Data entries in the npz file
+ jmap: [J, H, W] Junction heat map (H and W are 4x smaller)
+ joff: [J, 2, H, W] Junction offset within each pixel (Not sure about offsets)
+ lmap: [H, W] Line heat map with anti-aliasing (H and W are 4x smaller)
+ junc: [Na, 3] Junction coordinates (coordinates from 0~128 => 4x smaller.)
+ Lpos: [M, 2] Positive lines represented with junction indices
+ Lneg: [M, 2] Negative lines represented with junction indices
+ lpos: [Np, 2, 3] Positive lines represented with junction coordinates
+ lneg: [Nn, 2, 3] Negative lines represented with junction coordinates
+ """
+ label_path = data_path["label"]
+ label = np.load(label_path)
+ for key in list(label.keys()):
+ output[key] = label[key]
+
+ # If there's "line_mat" entry.
+ # TODO: How to process mat data
+ if data_path.get("line_mat") is not None:
+ raise NotImplementedError
+
+ return output
+
+ @staticmethod
+ def convert_line_map(lcnn_line_map, num_junctions):
+ """ Convert the line_pos or line_neg
+ (represented by two junction indexes) to our line map. """
+ # Initialize empty line map
+ line_map = np.zeros([num_junctions, num_junctions])
+
+ # Iterate through all the lines
+ for idx in range(lcnn_line_map.shape[0]):
+ index1 = lcnn_line_map[idx, 0]
+ index2 = lcnn_line_map[idx, 1]
+
+ line_map[index1, index2] = 1
+ line_map[index2, index1] = 1
+
+ return line_map
+
+ @staticmethod
+ def junc_to_junc_map(junctions, image_size):
+ """ Convert junction points to junction maps. """
+ junctions = np.round(junctions).astype(np.int32)
+ # Clip the boundary by image size
+ junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1)
+ junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1)
+
+ # Create junction map
+ junc_map = np.zeros([image_size[0], image_size[1]])
+ junc_map[junctions[:, 0], junctions[:, 1]] = 1
+
+ return junc_map[..., None].astype(np.int32)
+
+ def parse_transforms(self, names, all_transforms):
+ """ Parse the transform. """
+ trans = all_transforms if (names == 'all') \
+ else (names if isinstance(names, list) else [names])
+ assert set(trans) <= set(all_transforms)
+ return trans
+
+ def get_photo_transform(self):
+ """ Get list of photometric transforms (according to the config). """
+ # Get the photometric transform config
+ photo_config = self.config["augmentation"]["photometric"]
+ if not photo_config["enable"]:
+ raise ValueError(
+ "[Error] Photometric augmentation is not enabled.")
+
+ # Parse photometric transforms
+ trans_lst = self.parse_transforms(photo_config["primitives"],
+ photoaug.available_augmentations)
+ trans_config_lst = [photo_config["params"].get(p, {})
+ for p in trans_lst]
+
+ # List of photometric augmentation
+ photometric_trans_lst = [
+ getattr(photoaug, trans)(**conf) \
+ for (trans, conf) in zip(trans_lst, trans_config_lst)
+ ]
+
+ return photometric_trans_lst
+
+ def get_homo_transform(self):
+ """ Get homographic transforms (according to the config). """
+ # Get homographic transforms for image
+ homo_config = self.config["augmentation"]["homographic"]["params"]
+ if not self.config["augmentation"]["homographic"]["enable"]:
+ raise ValueError(
+ "[Error] Homographic augmentation is not enabled.")
+
+ # Parse the homographic transforms
+ image_shape = self.config["preprocessing"]["resize"]
+
+ # Compute the min_label_len from config
+ try:
+ min_label_tmp = self.config["generation"]["min_label_len"]
+ except:
+ min_label_tmp = None
+
+ # float label len => fraction
+ if isinstance(min_label_tmp, float): # Skip if not provided
+ min_label_len = min_label_tmp * min(image_shape)
+ # int label len => length in pixel
+ elif isinstance(min_label_tmp, int):
+ scale_ratio = (self.config["preprocessing"]["resize"]
+ / self.config["generation"]["image_size"][0])
+ min_label_len = (self.config["generation"]["min_label_len"]
+ * scale_ratio)
+ # if none => no restriction
+ else:
+ min_label_len = 0
+
+ # Initialize the transform
+ homographic_trans = homoaug.homography_transform(
+ image_shape, homo_config, 0, min_label_len)
+
+ return homographic_trans
+
+    def get_line_points(self, junctions, line_map, H1=None, H2=None,
+                        img_size=None, warp=False):
+        """ Sample evenly points along each line segments
+            and keep track of line idx.
+
+        Args:
+            junctions: Junction coordinates, indexed by the rows / columns
+                of ``line_map``.
+            line_map: Square map whose non-zero ``[i, j]`` entries (with
+                ``j > i``) mark a line between junctions i and j.
+            H1: Homography passed to ``warp_points`` for this view.
+            H2: Homography of the other view (or None).
+            img_size: Image size passed to ``mask_points``.
+            warp: Whether the sampled points are warped by ``H1``.
+
+        Returns:
+            A tuple ``(line_points, line_indices)``, both padded with zeros
+            to ``config["max_pts"]`` entries. Line indices start at 1; an
+            index of 0 marks a padded entry.
+
+        Raises:
+            ValueError: If ``warp`` is False but ``H1`` is given.
+        """
+        if np.sum(line_map) == 0:
+            # No segment detected in the image
+            line_indices = np.zeros(self.config["max_pts"], dtype=int)
+            line_points = np.zeros((self.config["max_pts"], 2), dtype=float)
+            return line_points, line_indices
+
+        # Extract all pairs of connected junctions
+        # (only the upper triangle, so each line is listed once)
+        junc_indices = np.array(
+            [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i])
+        line_segments = np.stack([junctions[junc_indices[:, 0]],
+                                  junctions[junc_indices[:, 1]]], axis=1)
+        # line_segments is (num_lines, 2, 2)
+        line_lengths = np.linalg.norm(
+            line_segments[:, 0] - line_segments[:, 1], axis=1)
+
+        # Sample the points separated by at least min_dist_pts along each line
+        # The number of samples depends on the length of the line
+        num_samples = np.minimum(line_lengths // self.config["min_dist_pts"],
+                                 self.config["max_num_samples"])
+        line_points = []
+        line_indices = []
+        cur_line_idx = 1
+        # The loop starts at n = 2, so lines too short to hold two samples
+        # are never sampled.
+        for n in np.arange(2, self.config["max_num_samples"] + 1):
+            # Consider all lines where we can fit up to n points
+            cur_line_seg = line_segments[num_samples == n]
+            line_points_x = np.linspace(cur_line_seg[:, 0, 0],
+                                        cur_line_seg[:, 1, 0],
+                                        n, axis=-1).flatten()
+            line_points_y = np.linspace(cur_line_seg[:, 0, 1],
+                                        cur_line_seg[:, 1, 1],
+                                        n, axis=-1).flatten()
+            jitter = self.config.get("jittering", 0)
+            if jitter:
+                # Add a small random jittering of all points along the line
+                angles = np.arctan2(
+                    cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0],
+                    cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n)
+                # Random signed offset in [-jitter, jitter) for every point
+                jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter
+                line_points_x += jitter_hyp * np.sin(angles)
+                line_points_y += jitter_hyp * np.cos(angles)
+            line_points.append(np.stack([line_points_x, line_points_y], axis=-1))
+            # Keep track of the line indices for each sampled point
+            num_cur_lines = len(cur_line_seg)
+            line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines)
+            line_indices.append(line_idx.repeat(n))
+            cur_line_idx += num_cur_lines
+        # Keep at most max_pts sampled points
+        line_points = np.concatenate(line_points,
+                                     axis=0)[:self.config["max_pts"]]
+        line_indices = np.concatenate(line_indices,
+                                      axis=0)[:self.config["max_pts"]]
+
+        # Warp the points if need be, and filter unvalid ones
+        # If the other view is also warped
+        if warp and H2 is not None:
+            warp_points2 = warp_points(line_points, H2)
+            line_points = warp_points(line_points, H1)
+            mask = mask_points(line_points, img_size)
+            mask2 = mask_points(warp_points2, img_size)
+            mask = mask * mask2
+        # If the other view is not warped
+        elif warp and H2 is None:
+            line_points = warp_points(line_points, H1)
+            mask = mask_points(line_points, img_size)
+        else:
+            if H1 is not None:
+                raise ValueError("[Error] Wrong combination of homographies.")
+            # Remove points that would be outside of img_size if warped by H
+            # NOTE(review): H1 is always None at this point; presumably the
+            # other view's homography (H2) was meant here - TODO confirm.
+            warped_points = warp_points(line_points, H1)
+            mask = mask_points(warped_points, img_size)
+        line_points = line_points[mask]
+        line_indices = line_indices[mask]
+
+        # Pad the line points to a fixed length
+        # Index of 0 means padded line
+        line_indices = np.concatenate([line_indices, np.zeros(
+            self.config["max_pts"] - len(line_indices))], axis=0)
+        line_points = np.concatenate(
+            [line_points,
+             np.zeros((self.config["max_pts"] - len(line_points), 2),
+                      dtype=float)], axis=0)
+
+        return line_points, line_indices
+
+ def train_preprocessing(self, data, numpy=False):
+ """ Train preprocessing for GT data. """
+ # Fetch the corresponding entries
+ image = data["image"]
+ junctions = data["junc"][:, :2]
+ line_pos = data["Lpos"]
+ line_neg = data["Lneg"]
+ image_size = image.shape[:2]
+ # Convert junctions to pixel coordinates (from 128x128)
+ junctions[:, 0] *= image_size[0] / 128
+ junctions[:, 1] *= image_size[1] / 128
+
+ # Resize the image before photometric and homographical augmentations
+ if not(list(image_size) == self.config["preprocessing"]["resize"]):
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # In HW format
+ junctions = (junctions * np.array(
+ self.config['preprocessing']['resize'], np.float)
+ / np.array(size_old, np.float))
+
+ # Convert to positive line map and negative line map (our format)
+ num_junctions = junctions.shape[0]
+ line_map_pos = self.convert_line_map(line_pos, num_junctions)
+ line_map_neg = self.convert_line_map(line_neg, num_junctions)
+
+ # Generate the line heatmap after post-processing
+ junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+ # Update image size
+ image_size = image.shape[:2]
+ heatmap_pos = get_line_heatmap(junctions_xy, line_map_pos, image_size)
+ heatmap_neg = get_line_heatmap(junctions_xy, line_map_neg, image_size)
+ # Declare default valid mask (all ones)
+ valid_mask = np.ones(image_size)
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+ # Check if we need to apply augmentations
+ # In training mode => yes.
+ # In homography adaptation mode (export mode) => No
+ if self.config["augmentation"]["photometric"]["enable"]:
+ photo_trans_lst = self.get_photo_transform()
+ ### Image transform ###
+ np.random.shuffle(photo_trans_lst)
+ image_transform = transforms.Compose(
+ photo_trans_lst + [photoaug.normalize_image()])
+ else:
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Check homographic augmentation
+ if self.config["augmentation"]["homographic"]["enable"]:
+ homo_trans = self.get_homo_transform()
+ # Perform homographic transform
+ outputs_pos = homo_trans(image, junctions, line_map_pos)
+ outputs_neg = homo_trans(image, junctions, line_map_neg)
+
+ # record the warped results
+ junctions = outputs_pos["junctions"] # Should be HW format
+ image = outputs_pos["warped_image"]
+ line_map_pos = outputs_pos["line_map"]
+ line_map_neg = outputs_neg["line_map"]
+ heatmap_pos = outputs_pos["warped_heatmap"]
+ heatmap_neg = outputs_neg["warped_heatmap"]
+ valid_mask = outputs_pos["valid_mask"] # Same for pos and neg
+
+ junction_map = self.junc_to_junc_map(junctions, image_size)
+
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+ if not numpy:
+ return {
+ "image": to_tensor(image),
+ "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+ "junction_map": to_tensor(junction_map).to(torch.int),
+ "line_map_pos": to_tensor(
+ line_map_pos).to(torch.int32)[0, ...],
+ "line_map_neg": to_tensor(
+ line_map_neg).to(torch.int32)[0, ...],
+ "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32),
+ "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32),
+ "valid_mask": to_tensor(valid_mask).to(torch.int32)
+ }
+ else:
+ return {
+ "image": image,
+ "junctions": junctions.astype(np.float32),
+ "junction_map": junction_map.astype(np.int32),
+ "line_map_pos": line_map_pos.astype(np.int32),
+ "line_map_neg": line_map_neg.astype(np.int32),
+ "heatmap_pos": heatmap_pos.astype(np.int32),
+ "heatmap_neg": heatmap_neg.astype(np.int32),
+ "valid_mask": valid_mask.astype(np.int32)
+ }
+
+ def train_preprocessing_exported(
+ self, data, numpy=False, disable_homoaug=False,
+ desc_training=False, H1=None, H1_scale=None, H2=None, scale=1.,
+ h_crop=None, w_crop=None):
+ """ Train preprocessing for the exported labels. """
+ data = copy.deepcopy(data)
+ # Fetch the corresponding entries
+ image = data["image"]
+ junctions = data["junctions"]
+ line_map = data["line_map"]
+ image_size = image.shape[:2]
+
+ # Define the random crop for scaling if necessary
+ if h_crop is None or w_crop is None:
+ h_crop, w_crop = 0, 0
+ if scale > 1:
+ H, W = self.config["preprocessing"]["resize"]
+ H_scale, W_scale = round(H * scale), round(W * scale)
+ if H_scale > H:
+ h_crop = np.random.randint(H_scale - H)
+ if W_scale > W:
+ w_crop = np.random.randint(W_scale - W)
+
+ # Resize the image before photometric and homographical augmentations
+ if not(list(image_size) == self.config["preprocessing"]["resize"]):
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # In HW format
+ junctions = (junctions * np.array(
+ self.config['preprocessing']['resize'], np.float)
+ / np.array(size_old, np.float))
+
+ # Generate the line heatmap after post-processing
+ junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+ image_size = image.shape[:2]
+ heatmap = get_line_heatmap(junctions_xy, line_map, image_size)
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+ # Check if we need to apply augmentations
+ # In training mode => yes.
+ # In homography adaptation mode (export mode) => No
+ if self.config["augmentation"]["photometric"]["enable"]:
+ photo_trans_lst = self.get_photo_transform()
+ ### Image transform ###
+ np.random.shuffle(photo_trans_lst)
+ image_transform = transforms.Compose(
+ photo_trans_lst + [photoaug.normalize_image()])
+ else:
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Perform the random scaling
+ if scale != 1.:
+ image, junctions, line_map, valid_mask = random_scaling(
+ image, junctions, line_map, scale,
+ h_crop=h_crop, w_crop=w_crop)
+ else:
+ # Declare default valid mask (all ones)
+ valid_mask = np.ones(image_size)
+
+ # Initialize the empty output dict
+ outputs = {}
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+
+ # Check homographic augmentation
+ warp = (self.config["augmentation"]["homographic"]["enable"]
+ and disable_homoaug == False)
+ if warp:
+ homo_trans = self.get_homo_transform()
+ # Perform homographic transform
+ if H1 is None:
+ homo_outputs = homo_trans(
+ image, junctions, line_map, valid_mask=valid_mask)
+ else:
+ homo_outputs = homo_trans(
+ image, junctions, line_map, homo=H1, scale=H1_scale,
+ valid_mask=valid_mask)
+ homography_mat = homo_outputs["homo"]
+
+ # Give the warp of the other view
+ if H1 is None:
+ H1 = homo_outputs["homo"]
+
+ # Sample points along each line segments for the descriptor
+ if desc_training:
+ line_points, line_indices = self.get_line_points(
+ junctions, line_map, H1=H1, H2=H2,
+ img_size=image_size, warp=warp)
+
+ # Record the warped results
+ if warp:
+ junctions = homo_outputs["junctions"] # Should be HW format
+ image = homo_outputs["warped_image"]
+ line_map = homo_outputs["line_map"]
+ valid_mask = homo_outputs["valid_mask"] # Same for pos and neg
+ # heatmap = homo_outputs["warped_heatmap"]
+
+ # Optionally put warping information first.
+ if not numpy:
+ outputs["homography_mat"] = to_tensor(
+ homography_mat).to(torch.float32)[0, ...]
+ else:
+ outputs["homography_mat"] = homography_mat.astype(np.float32)
+
+ junction_map = self.junc_to_junc_map(junctions, image_size)
+
+ if not numpy:
+ outputs.update({
+ "image": to_tensor(image).to(torch.float32),
+ "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+ "junction_map": to_tensor(junction_map).to(torch.int),
+ "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+ # "heatmap": to_tensor(heatmap).to(torch.int32),
+ "valid_mask": to_tensor(valid_mask).to(torch.int32)
+ })
+ if desc_training:
+ outputs.update({
+ "line_points": to_tensor(
+ line_points).to(torch.float32)[0],
+ "line_indices": torch.tensor(line_indices,
+ dtype=torch.int)
+ })
+ else:
+ outputs.update({
+ "image": image,
+ "junctions": junctions.astype(np.float32),
+ "junction_map": junction_map.astype(np.int32),
+ "line_map": line_map.astype(np.int32),
+ # "heatmap": heatmap.astype(np.int32),
+ "valid_mask": valid_mask.astype(np.int32)
+ })
+ if desc_training:
+ outputs.update({
+ "line_points": line_points.astype(np.float32),
+ "line_indices": line_indices.astype(int)
+ })
+
+ return outputs
+
+    def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.):
+        """ Train preprocessing for paired data for the exported labels
+            for descriptor training.
+
+        Builds two views of the same sample with
+        train_preprocessing_exported and merges them in one dict:
+          * the "target" view is called with H1=None and H2=ref_H,
+          * the "ref" view is called with H1=ref_H (sampled here) and
+            H2=<homography returned by the target view>.
+        Both views share the same scale and crop offsets.
+
+        Args:
+            data: raw sample, forwarded unchanged to
+                train_preprocessing_exported.
+            numpy: forwarded to train_preprocessing_exported.
+            scale: scaling factor; when > 1, one random crop offset per
+                axis is drawn here and shared by both views.
+        Returns:
+            dict containing every entry of the ref view prefixed with
+            "ref_" and every entry of the target view prefixed with
+            "target_".
+        """
+        outputs = {}
+
+        # Define the random crop for scaling if necessary
+        h_crop, w_crop = 0, 0
+        if scale > 1:
+            H, W = self.config["preprocessing"]["resize"]
+            H_scale, W_scale = round(H * scale), round(W * scale)
+            if H_scale > H:
+                h_crop = np.random.randint(H_scale - H)
+            if W_scale > W:
+                w_crop = np.random.randint(W_scale - W)
+
+        # Sample ref homography first
+        homo_config = self.config["augmentation"]["homographic"]["params"]
+        image_shape = self.config["preprocessing"]["resize"]
+        ref_H, ref_scale = homoaug.sample_homography(image_shape,
+                                                     **homo_config)
+
+        # Data for target view (its own homography is chosen by the callee)
+        target_data = self.train_preprocessing_exported(
+            data, numpy=numpy, desc_training=True, H1=None, H2=ref_H,
+            scale=scale, h_crop=h_crop, w_crop=w_crop)
+
+        # The target homography may come back as a torch tensor or already
+        # as a numpy array (presumably depending on `numpy`); only tensors
+        # have .numpy(), so convert conditionally instead of crashing.
+        target_H = target_data["homography_mat"]
+        if isinstance(target_H, torch.Tensor):
+            target_H = target_H.numpy()
+
+        # Data for reference view (warped with the homography sampled above)
+        ref_data = self.train_preprocessing_exported(
+            data, numpy=numpy, desc_training=True, H1=ref_H,
+            H1_scale=ref_scale, H2=target_H,
+            scale=scale, h_crop=h_crop, w_crop=w_crop)
+
+        # Spread ref data
+        for key, val in ref_data.items():
+            outputs["ref_" + key] = val
+
+        # Spread target data
+        for key, val in target_data.items():
+            outputs["target_" + key] = val
+
+        return outputs
+
+    def test_preprocessing(self, data, numpy=False):
+        """ Test preprocessing for GT data.
+
+        Args:
+            data: dict with the entries "image", "junc", "Lpos" and "Lneg".
+                A deep copy is made, so the caller's dict is not modified.
+            numpy: if True return numpy arrays, otherwise torch tensors.
+        Returns:
+            dict with the keys "image", "junctions", "junction_map",
+            "line_map_pos", "line_map_neg", "heatmap_pos", "heatmap_neg"
+            and "valid_mask".
+        """
+        data = copy.deepcopy(data)
+        # Fetch the corresponding entries
+        image = data["image"]
+        junctions = data["junc"][:, :2]
+        line_pos = data["Lpos"]
+        line_neg = data["Lneg"]
+        image_size = image.shape[:2]
+        # Convert junctions to pixel coordinates (from 128x128)
+        junctions[:, 0] *= image_size[0] / 128
+        junctions[:, 1] *= image_size[1] / 128
+
+        # Resize the image before photometric and homographical augmentations
+        target_size = self.config["preprocessing"]["resize"]
+        if not(list(image_size) == target_size):
+            # Resize the image and the point location.
+            size_old = list(image.shape)[:2]  # Only H and W dimensions
+
+            image = cv2.resize(
+                image, tuple(target_size[::-1]),
+                interpolation=cv2.INTER_LINEAR)
+            image = np.array(image, dtype=np.uint8)
+
+            # In HW format. The builtin `float` replaces the `np.float`
+            # alias, which no longer exists in NumPy >= 1.24.
+            junctions = (junctions * np.array(target_size, float)
+                         / np.array(size_old, float))
+
+        # Optionally convert the image to grayscale
+        if self.config["gray_scale"]:
+            image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+        # Still need to normalize image
+        image_transform = photoaug.normalize_image()
+        image = image_transform(image)
+
+        # Convert to positive line map and negative line map (our format)
+        num_junctions = junctions.shape[0]
+        line_map_pos = self.convert_line_map(line_pos, num_junctions)
+        line_map_neg = self.convert_line_map(line_neg, num_junctions)
+
+        # Generate the line heatmap after post-processing
+        junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+        # Update image size
+        image_size = image.shape[:2]
+        heatmap_pos = get_line_heatmap(junctions_xy, line_map_pos, image_size)
+        heatmap_neg = get_line_heatmap(junctions_xy, line_map_neg, image_size)
+        # Declare default valid mask (all ones)
+        valid_mask = np.ones(image_size)
+
+        junction_map = self.junc_to_junc_map(junctions, image_size)
+
+        # Convert to tensor and return the results
+        to_tensor = transforms.ToTensor()
+        if not numpy:
+            return {
+                "image": to_tensor(image),
+                "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+                "junction_map": to_tensor(junction_map).to(torch.int),
+                "line_map_pos": to_tensor(
+                    line_map_pos).to(torch.int32)[0, ...],
+                "line_map_neg": to_tensor(
+                    line_map_neg).to(torch.int32)[0, ...],
+                "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32),
+                "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32),
+                "valid_mask": to_tensor(valid_mask).to(torch.int32)
+            }
+        else:
+            return {
+                "image": image,
+                "junctions": junctions.astype(np.float32),
+                "junction_map": junction_map.astype(np.int32),
+                "line_map_pos": line_map_pos.astype(np.int32),
+                "line_map_neg": line_map_neg.astype(np.int32),
+                "heatmap_pos": heatmap_pos.astype(np.int32),
+                "heatmap_neg": heatmap_neg.astype(np.int32),
+                "valid_mask": valid_mask.astype(np.int32)
+            }
+
+    def test_preprocessing_exported(self, data, numpy=False, scale=1.):
+        """ Test preprocessing for the exported labels.
+
+        Args:
+            data: dict with the entries "image", "junctions" and
+                "line_map". A deep copy is made before processing.
+            numpy: if True return numpy arrays, otherwise torch tensors.
+            scale: accepted for signature compatibility; not used by this
+                method.
+        Returns:
+            dict with the keys "image", "junctions", "junction_map",
+            "line_map", "heatmap" and "valid_mask".
+        """
+        data = copy.deepcopy(data)
+        # Fetch the corresponding entries
+        image = data["image"]
+        junctions = data["junctions"]
+        line_map = data["line_map"]
+        image_size = image.shape[:2]
+
+        # Resize the image before photometric and homographical augmentations
+        target_size = self.config["preprocessing"]["resize"]
+        if not(list(image_size) == target_size):
+            # Resize the image and the point location.
+            size_old = list(image.shape)[:2]  # Only H and W dimensions
+
+            image = cv2.resize(
+                image, tuple(target_size[::-1]),
+                interpolation=cv2.INTER_LINEAR)
+            image = np.array(image, dtype=np.uint8)
+
+            # In HW format. The builtin `float` replaces the `np.float`
+            # alias, which no longer exists in NumPy >= 1.24.
+            junctions = (junctions * np.array(target_size, float)
+                         / np.array(size_old, float))
+
+        # Optionally convert the image to grayscale
+        if self.config["gray_scale"]:
+            image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+        # Still need to normalize image
+        image_transform = photoaug.normalize_image()
+        image = image_transform(image)
+
+        # Generate the line heatmap after post-processing
+        junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+        image_size = image.shape[:2]
+        heatmap = get_line_heatmap(junctions_xy, line_map, image_size)
+
+        # Declare default valid mask (all ones)
+        valid_mask = np.ones(image_size)
+
+        junction_map = self.junc_to_junc_map(junctions, image_size)
+
+        # Convert to tensor and return the results
+        to_tensor = transforms.ToTensor()
+        if not numpy:
+            outputs = {
+                "image": to_tensor(image),
+                "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+                "junction_map": to_tensor(junction_map).to(torch.int),
+                "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+                "heatmap": to_tensor(heatmap).to(torch.int32),
+                "valid_mask": to_tensor(valid_mask).to(torch.int32)
+            }
+        else:
+            outputs = {
+                "image": image,
+                "junctions": junctions.astype(np.float32),
+                "junction_map": junction_map.astype(np.int32),
+                "line_map": line_map.astype(np.int32),
+                "heatmap": heatmap.astype(np.int32),
+                "valid_mask": valid_mask.astype(np.int32)
+            }
+
+        return outputs
+
+    def __len__(self):
+        """ Number of datapoints, as stored in self.dataset_length. """
+        length = self.dataset_length
+        return length
+
+    def get_data_from_key(self, file_key):
+        """ Load the sample stored under file_key and preprocess it.
+
+        The preprocessing is always run with numpy=True. The train
+        pipeline is used in "train" mode or when the config entry
+        "add_augmentation_to_all_splits" is set, the test pipeline
+        otherwise. Raises ValueError for an unknown key.
+        """
+        # Reject keys that are not part of the filename dataset
+        if file_key not in self.filename_dataset:
+            raise ValueError("[Error] the specified key is not in the dataset.")
+
+        # Read in the image and npz labels (no transform applied yet)
+        sample = self.get_data_from_path(self.filename_dataset[file_key])
+
+        # Pick the preprocessing pipeline, then run it
+        augment = (self.mode == "train"
+                   or self.config["add_augmentation_to_all_splits"])
+        if augment:
+            sample = self.train_preprocessing(sample, numpy=True)
+        else:
+            sample = self.test_preprocessing(sample, numpy=True)
+
+        # Add file key to the output
+        sample["file_key"] = file_key
+        return sample
+
+ def __getitem__(self, idx):
+ """Return data
+ file_key: str, keys used to retrieve data from the filename dataset.
+ image: torch.float, C*H*W range 0~1,
+ junctions: torch.float, N*2,
+ junction_map: torch.int32, 1*H*W range 0 or 1,
+ line_map_pos: torch.int32, N*N range 0 or 1,
+ line_map_neg: torch.int32, N*N range 0 or 1,
+ heatmap_pos: torch.int32, 1*H*W range 0 or 1,
+ heatmap_neg: torch.int32, 1*H*W range 0 or 1,
+ valid_mask: torch.int32, 1*H*W range 0 or 1
+ """
+ # Purpose: load sample `idx`, attach exported labels when gt_source is
+ # not "official", run the train or test preprocessing and add "file_key".
+ # NOTE(review): the key list in the docstring matches what
+ # test_preprocessing returns; test_preprocessing_exported returns
+ # "line_map" / "heatmap" instead of the *_pos / *_neg entries.
+ # Get the corresponding datapoint and contents from filename dataset
+ file_key = self.datapoints[idx]
+ data_path = self.filename_dataset[file_key]
+ # Read in the image and npz labels (but haven't applied any transform)
+ data = self.get_data_from_path(data_path)
+ # Also load the exported labels if not using the official ground truth
+ # (self.gt_source is then used as the path of an h5 file)
+ if not self.gt_source == "official":
+ with h5py.File(self.gt_source, "r") as f:
+ exported_label = parse_h5_data(f[file_key])
+
+ data["junctions"] = exported_label["junctions"]
+ data["line_map"] = exported_label["line_map"]
+ # Reverse the order of the junction columns.
+ # NOTE(review): presumably an xy <-> hw conversion of the exported
+ # junctions - confirm. The indentation of this region is collapsed in
+ # this view, so whether this statement and the two assignments above are
+ # nested under the `if` cannot be confirmed here.
+ data["junctions"] = data["junctions"][:,::-1]
+ # Perform transform and augmentation
+ return_type = self.config.get("return_type", "single")
+ if (self.mode == "train"
+ or self.config["add_augmentation_to_all_splits"]):
+ # Perform random scaling first
+ if self.config["augmentation"]["random_scaling"]["enable"]:
+ scale_range = self.config["augmentation"]["random_scaling"]["range"]
+ # Decide the scaling
+ scale = np.random.uniform(min(scale_range), max(scale_range))
+ else:
+ scale = 1.
+ # The sampled scale is only passed to the exported-label pipelines
+ if self.gt_source == "official":
+ data = self.train_preprocessing(data)
+ else:
+ if return_type == "paired_desc":
+ data = self.preprocessing_exported_paired_desc(
+ data, scale=scale)
+ else:
+ data = self.train_preprocessing_exported(data,
+ scale=scale)
+ else:
+ # Test-time path: no scale is sampled, callees use their default
+ if self.gt_source == "official":
+ data = self.test_preprocessing(data)
+ elif return_type == "paired_desc":
+ data = self.preprocessing_exported_paired_desc(data)
+ else:
+ data = self.test_preprocessing_exported(data)
+
+ # Add file key to the output
+ data["file_key"] = file_key
+
+ return data
+
+ ########################
+ ## Some other methods ##
+ ########################
+    def _check_dataset_cache(self):
+        """ Return True when the file cache_path/cache_name exists. """
+        return os.path.exists(os.path.join(self.cache_path, self.cache_name))
+
+ # Check if the repeatability cache dataset exists
+    def _check_rep_eval_dataset_cache(self, split):
+        """ Return True when the repeatability cache file of `split` exists.
+
+        "i" selects rep_i_cache_name; every other value selects
+        rep_v_cache_name.
+        """
+        cache_name = (self.rep_i_cache_name if split == "i"
+                      else self.rep_v_cache_name)
+        return os.path.exists(os.path.join(self.cache_path, cache_name))
\ No newline at end of file
diff --git a/scalelsd/ssl/datasets/yorkurban_dataset.py b/scalelsd/ssl/datasets/yorkurban_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..de6c460957fe4b2d7f7c9c9f16af9ecc04e05b60
--- /dev/null
+++ b/scalelsd/ssl/datasets/yorkurban_dataset.py
@@ -0,0 +1,479 @@
+from torch.utils.data import Dataset
+import os
+import math
+from tqdm import tqdm
+from skimage.io import imread
+from skimage import color
+import PIL
+import numpy as np
+import h5py
+import cv2
+import pickle
+from .synthetic_util import get_line_heatmap
+from torchvision import transforms
+import torch
+import torch.utils.data.dataloader as torch_loader
+# Augmentation libs
+from ..config.project_config import Config as cfg
+from .transforms import photometric_transforms as photoaug
+from .transforms import homographic_transforms as homoaug
+# Some visualization tools
+from ..misc.visualize_util import plot_junctions, plot_line_segments
+# Some data parsing tools
+from ..misc.train_utils import parse_h5_data
+# Inherit from private dataset
+# from dataset.private_dataset import PrivateDataset
+
+
+# Implements the customized collate_fn for yorkurban dataset
+def yorkurban_collate_fn(batch):
+    """ Collate a list of sample dicts into one batch dict.
+
+    Keys are classified by substring match on the key name: keys matching
+    a "batch" pattern are stacked with default_collate, keys matching a
+    "list" pattern are kept as a plain python list, keys matching neither
+    are dropped, and keys matching both raise ValueError. The key set is
+    taken from the first sample.
+    """
+    stack_patterns = ("image", "junction_map", "valid_mask", "heatmap",
+                      "heatmap_pos", "heatmap_neg", "homography")
+    list_patterns = ("junctions", "line_map", "line_map_pos", "line_map_neg",
+                     "file_key", "aux_junctions", "aux_line_map")
+
+    outputs = {}
+    for data_key in batch[0].keys():
+        stackable = any(pattern in data_key for pattern in stack_patterns)
+        listable = any(pattern in data_key for pattern in list_patterns)
+
+        if stackable and listable:
+            raise ValueError("[Error] A key matches batch keys and list keys simultaneously.")
+        if stackable:
+            outputs[data_key] = torch_loader.default_collate(
+                [sample[data_key] for sample in batch])
+        elif listable:
+            outputs[data_key] = [sample[data_key] for sample in batch]
+        # Keys matching neither family are silently skipped
+
+    return outputs
+
+# The York Urban dataset (images only; only the "test" mode is supported).
+class YorkUrbanDataset(Dataset):
+    """ York Urban image dataset (only the "test" mode is available).
+
+    Samples contain an image and an all-ones valid mask; no line labels
+    are loaded by this class. When the config has an "evaluation" entry,
+    a repeatability evaluation set (pairs of a reference image and a
+    homography-warped image) is also prepared and cached on disk.
+    """
+    # Initialize the dataset
+    def __init__(self, mode="test", config=None):
+        super(YorkUrbanDataset, self).__init__()
+        # Check mode => only "test" is supported
+        if mode != "test":
+            raise ValueError("[Error] Unknown mode for york urban dataset. Only 'test' mode is available.")
+        self.mode = mode
+
+        self.config = config
+
+        # Get cache setting
+        self.dataset_name = self.get_dataset_name()
+        self.cache_name = self.get_cache_name()
+        self.cache_path = cfg.yorkurban_cache_path
+
+        # Get the filename dataset
+        print("[Info] Initializing york urban dataset...")
+        self.filename_dataset, self.datapoints = self.get_filename_dataset()
+        # Get dataset length
+        self.dataset_length = len(self.datapoints)
+
+        # Get repeatability evaluation set
+        if self.mode == "test" and self.config.get("evaluation", None) is not None:
+            # Derive the cache names by inserting "_rep_i" / "_rep_v"
+            # right after the mode inside the cache name
+            tmp = self.cache_name.split(self.mode)
+            self.rep_i_cache_name = tmp[0] + self.mode + "_rep_i" + tmp[1]
+            self.rep_v_cache_name = tmp[0] + self.mode + "_rep_v" + tmp[1]
+
+            # Get the repeatability config
+            self.rep_config = self.config["evaluation"]["repeatability"]
+
+            self.rep_eval_dataset = self.construct_rep_eval_dataset()
+            self.rep_eval_datapoints = self.get_rep_eval_datapoints()
+
+        # Print some info
+        print("[Info] Successfully initialized dataset")
+        print("\t Name: yorkurban")
+        print("\t Mode: %s" %(self.mode))
+        print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode), "official")))
+        print("\t Counts: %d" %(self.dataset_length))
+        print("----------------------------------------")
+
+    def get_filename_dataset(self):
+        """ Scan the data root and build the key -> {"image": path} mapping.
+
+        Returns:
+            (filename_dataset, datapoints) where datapoints is the sorted
+            list of keys.
+        Raises:
+            ValueError: if an image path does not exist or no image is found.
+        """
+        # Get the path to the dataset
+        if self.mode == "train":
+            raise NotImplementedError
+        elif self.mode == "test":
+            dataset_path = os.path.join(cfg.yorkurban_dataroot)
+        # Get paths to all sub-folders
+        folder_lst = sorted(
+            os.path.join(dataset_path, name)
+            for name in os.listdir(dataset_path)
+            if os.path.isdir(os.path.join(dataset_path, name)))
+        # NOTE(review): the last sub-folder (in sorted order) is skipped;
+        # the reason is not visible in this file - confirm it is intended.
+        folder_lst = folder_lst[:-1]
+        #folder_lst = [f for f in folder_lst if f.startswith('P')]
+        image_paths = []
+        for folder in folder_lst:
+            image_paths += [os.path.join(folder, name)
+                            for name in os.listdir(folder)
+                            if os.path.splitext(name)[-1] in (".jpg", ".png")]
+
+        # Verify all the images exist
+        for image_path in image_paths:
+            if not os.path.exists(image_path):
+                raise ValueError("[Error] The image does not exist. %s"%(image_path))
+
+        # math.log10(0) below would fail with an opaque "math domain error"
+        if not image_paths:
+            raise ValueError("[Error] No .jpg / .png image found under %s" % (dataset_path))
+
+        # Construct the filename dataset
+        num_pad = int(math.ceil(math.log10(len(image_paths))) + 1)
+        filename_dataset = {}
+        for idx, image_path in enumerate(image_paths):
+            # Get the file key
+            key = self.get_padded_filename(num_pad, idx)
+            filename_dataset[key] = {"image": image_path}
+
+        # Get the datapoints
+        datapoints = list(sorted(filename_dataset.keys()))
+
+        return filename_dataset, datapoints
+
+    # Get the padded filename using adaptive padding
+    @staticmethod
+    def get_padded_filename(num_pad, idx):
+        """ Return idx as a decimal string left-padded with "0" to num_pad. """
+        file_len = len("%d" % (idx))
+        filename = "0" * (num_pad - file_len) + "%d" % (idx)
+
+        return filename
+
+    # Get dataset name from dataset config / default config
+    def get_dataset_name(self):
+        """ Return "<config dataset_name or yorkurban_dataset>_<mode>". """
+        if self.config["dataset_name"] is None:
+            dataset_name = "yorkurban_dataset" + f"_{self.mode}"
+        else:
+            dataset_name = self.config["dataset_name"] + f"_{self.mode}"
+
+        return dataset_name
+
+    # Get cache name from dataset config / default config
+    def get_cache_name(self):
+        """ Return the dataset name followed by "_cache.pkl". """
+        return self.get_dataset_name() + "_cache.pkl"
+
+    ###########################################
+    ## Repeatability evaluation related APIs ##
+    ###########################################
+    # Construct repeatability evaluation dataset (from scratch or from cache)
+    def construct_rep_eval_dataset(self):
+        """ Build the {"illumination": ..., "viewpoint": ...} description.
+
+        Each entry is {"keymap": ..., "dataset_name": ...}; both values are
+        None when the corresponding split is disabled in the config.
+        """
+        rep_eval_dataset = {}
+        splits = (("i", "illumination", "photometric"),
+                  ("v", "viewpoint", "homographic"))
+        for split, name, config_key in splits:
+            if self.rep_config[config_key]["enable"]:
+                keymap, dataset_name = self._load_or_create_rep_split(split, name)
+            else:
+                keymap, dataset_name = None, None
+
+            rep_eval_dataset[name] = {
+                "keymap": keymap,
+                "dataset_name": dataset_name
+            }
+
+        return rep_eval_dataset
+
+    # Load one repeatability split from cache, or create it and cache it
+    def _load_or_create_rep_split(self, split, name):
+        cache_name = self.rep_i_cache_name if split == "i" else self.rep_v_cache_name
+        if self._check_rep_eval_dataset_cache(split=split):
+            print("\t Found repeatability %s cache %s at %s"%(name, cache_name, self.cache_path))
+            print("\t Load repeatability %s cache..."%(name))
+            return self.get_rep_eval_dataset_from_cache(split=split)
+
+        print("\t Can't find repeatability %s cache ..."%(name))
+        print("\t Create repeatability %s dataset from scratch..."%(name))
+        keymap, dataset_name = self.get_rep_eval_dataset(split=split)
+        print("\t Create filename dataset cache...")
+        self.create_rep_eval_dataset_cache(split, keymap, dataset_name)
+        return keymap, dataset_name
+
+    # Resolve the pkl cache path of a split ("i" or "v"), rejecting others
+    def _rep_cache_file_path(self, split):
+        if split == "i":
+            return os.path.join(self.cache_path, self.rep_i_cache_name)
+        elif split == "v":
+            return os.path.join(self.cache_path, self.rep_v_cache_name)
+        raise ValueError("[Error] Unknown split for repeatability evaluation.")
+
+    # Create filename dataset cache for faster initialization
+    def create_rep_eval_dataset_cache(self, split, keymap, dataset_name):
+        """ Pickle {"keymap", "dataset_name"} to the cache file of `split`. """
+        # Make sure the cache folder exists
+        os.makedirs(self.cache_path, exist_ok=True)
+
+        cache_file_path = self._rep_cache_file_path(split)
+
+        data = {
+            "keymap": keymap,
+            "dataset_name": dataset_name
+        }
+        with open(cache_file_path, "wb") as f:
+            pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
+
+    # Get filename dataset from cache
+    def get_rep_eval_dataset_from_cache(self, split):
+        """ Load (keymap, dataset_name) from the pkl cache of `split`. """
+        cache_file_path = self._rep_cache_file_path(split)
+
+        # NOTE: pickle must only be used on cache files written by this class
+        with open(cache_file_path, "rb") as f:
+            data = pickle.load(f)
+
+        return data["keymap"], data["dataset_name"]
+
+    # Initialize the repeatability evaluation dataset from scratch
+    def get_rep_eval_dataset(self, split):
+        """ Create the repeatability data of `split` from scratch.
+
+        Only "v" (viewpoint) is implemented: for every file key,
+        "num_samples" homographies are sampled and stored in an h5 file
+        inside the cache folder.
+
+        Returns:
+            (keymap, h5 dataset file name) where keymap maps each file key
+            to the list of its homography keys.
+        """
+        image_shape = self.config["preprocessing"]["resize"]
+
+        # Initialize the illumination set
+        if split == "i":
+            # Set the random seed before continuing
+            seed = self.rep_config["seed"]
+            np.random.seed(seed)
+            torch.manual_seed(seed)
+
+            raise NotImplementedError
+
+        # Initialize the viewpoint set
+        elif split == "v":
+            # Set the random seed before continuing
+            seed = self.rep_config["seed"]
+            np.random.seed(seed)
+            torch.manual_seed(seed)
+
+            homo_config = self.rep_config["homographic"]
+            v_keymap = {}
+            # Get the name for the output h5 dataset
+            v_dataset_name = self.rep_v_cache_name.split(".pkl")[0] + ".h5"
+            v_dataset_path = os.path.join(self.cache_path, v_dataset_name)
+            # The h5 file is written before the pkl cache, so the cache
+            # folder may not exist yet at this point
+            os.makedirs(self.cache_path, exist_ok=True)
+            with h5py.File(v_dataset_path, "w") as f:
+                # Iterate through all the file_key in test set
+                for key in tqdm(list(self.filename_dataset.keys()), ascii=True):
+                    # Sample N random homography
+                    file_key_lst = []
+                    for i in range(homo_config["num_samples"]):
+                        file_key = key + "_" + str(i)
+
+                        # Sample a random homography
+                        homo_mat, _ = homoaug.sample_homography(
+                            image_shape, **homo_config["params"])
+
+                        file_key_lst.append(file_key)
+                        f.create_dataset(file_key, data=homo_mat, compression="gzip")
+
+                    v_keymap[key] = file_key_lst
+
+            return v_keymap, v_dataset_name
+
+        else:
+            raise ValueError("[Error] Unknow split for repeatability evaluation.")
+
+    # Convert ref image and warped images to list of evaluation pairs
+    def get_rep_eval_datapoints(self):
+        """ Flatten the keymaps into lists of [ref_key, target_key] pairs. """
+        datapoints = {
+            "illumination": [],
+            "viewpoint": []
+        }
+
+        # Iterate through all the ref image of every available split
+        for name in ("illumination", "viewpoint"):
+            keymap = self.rep_eval_dataset[name]["keymap"]
+            if keymap is None:
+                continue
+            for ref_key in sorted(keymap.keys()):
+                datapoints[name] += [[ref_key, target_key]
+                                     for target_key in keymap[ref_key]]
+
+        return datapoints
+
+    # Check if the repeatability cache dataset exists
+    def _check_rep_eval_dataset_cache(self, split):
+        # Any value other than "i" is treated as the viewpoint split
+        if split == "i":
+            cache_file_path = os.path.join(self.cache_path, self.rep_i_cache_name)
+        else:
+            cache_file_path = os.path.join(self.cache_path, self.rep_v_cache_name)
+
+        return os.path.exists(cache_file_path)
+
+    ###########################################
+    ## Repeatability evaluation related APIs ##
+    ###########################################
+    # Get the corresponding data according to the "index in rep_eval_datapoints".
+    def get_rep_eval_data(self, split, idx):
+        """ Return one repeatability pair.
+
+        Args:
+            split: "viewpoint" or "illumination" (the latter raises
+                NotImplementedError).
+            idx: index into self.rep_eval_datapoints[split].
+        Returns:
+            dict with "ref_image", "ref_key", "target_image", "target_key"
+            and "homo_mat".
+        """
+        assert split in ["viewpoint", "illumination"]
+        datapoint = self.rep_eval_datapoints[split][idx]
+
+        # Get reference image
+        ref_key = datapoint[0]
+        # Get the data paths
+        data_path = self.filename_dataset[ref_key]
+        # Read in the image and apply the shared resize / gray / normalize
+        image = self._prepare_image(imread(data_path["image"]))
+
+        # Get target image
+        if split == "viewpoint":
+            target_key = datapoint[1]
+            dataset_path = os.path.join(self.cache_path, self.rep_eval_dataset[split]["dataset_name"])
+            with h5py.File(dataset_path, "r") as f:
+                homo_mat = np.array(f[target_key])
+
+            # Warp the image (cv2 expects the size as (width, height))
+            target_size = (image.shape[1], image.shape[0])
+            target_image = cv2.warpPerspective(image, homo_mat, target_size,
+                                               flags=cv2.INTER_LINEAR)
+
+        else:
+            raise NotImplementedError
+
+        # Convert to tensor and return the results
+        to_tensor = transforms.ToTensor()
+
+        return {
+            "ref_image": to_tensor(image),
+            "ref_key": ref_key,
+            "target_image": to_tensor(target_image),
+            "target_key": target_key,
+            "homo_mat": homo_mat
+        }
+
+    ############################################
+    ## Pytorch and preprocessing related APIs ##
+    ############################################
+    # Get the length of the dataset
+    def __len__(self):
+        return self.dataset_length
+
+    # Get data from the information from filename dataset
+    @staticmethod
+    def get_data_from_path(data_path):
+        """ Read the image referenced by data_path["image"]. """
+        return {"image": imread(data_path["image"])}
+
+    # Resize (if needed), optionally convert to grayscale, then normalize
+    def _prepare_image(self, image):
+        resize = self.config["preprocessing"]["resize"]
+        if list(image.shape[:2]) != resize:
+            # cv2.resize expects the size as (width, height)
+            image = cv2.resize(image, tuple(resize[::-1]),
+                               interpolation=cv2.INTER_LINEAR)
+            image = np.array(image, dtype=np.uint8)
+
+        # Optionally convert the image to grayscale
+        if self.config["gray_scale"]:
+            image = (color.rgb2gray(image) *255.).astype(np.uint8)
+
+        image_transform = photoaug.normalize_image()
+        return image_transform(image)
+
+    # The test preprocessing
+    def test_preprocessing(self, data, numpy=False):
+        """ Prepare data["image"] and build an all-ones valid mask.
+
+        Returns a dict with "image" and "valid_mask", as numpy arrays when
+        numpy is True and as torch tensors otherwise.
+        """
+        image = self._prepare_image(data["image"])
+
+        # The mask has the (possibly resized) image size
+        valid_mask = np.ones(image.shape[:2])
+
+        # Convert to tensor and return the results
+        if not numpy:
+            to_tensor = transforms.ToTensor()
+            return {
+                "image": to_tensor(image),
+                "valid_mask": to_tensor(valid_mask).to(torch.int32)
+            }
+        else:
+            return {
+                "image": image,
+                "valid_mask": valid_mask.astype(np.int32)
+            }
+
+    # Define the getitem method
+    def __getitem__(self, idx):
+        """Return data
+        file_key: str, keys used to retrieve certain data from the filename dataset.
+        image: torch.float, C*H*W range 0~1,
+        valid_mask: torch.int32, 1*H*W range 0 or 1
+        """
+        # Get the corresponding datapoint and get contents from filename dataset
+        file_key = self.datapoints[idx]
+        data_path = self.filename_dataset[file_key]
+        # Read in the image (but haven't applied any transform)
+        data = self.get_data_from_path(data_path)
+
+        # Perform transform and augmentation
+        if self.mode == "train" or self.config["add_augmentation_to_all_splits"]:
+            raise NotImplementedError
+        else:
+            data = self.test_preprocessing(data)
+
+        # Add file key to the output
+        data["file_key"] = file_key
+
+        return data
+
+if __name__ == "__main__":
+    # Manual check: build the test dataset from the yaml config and drop
+    # into an interactive debugger to inspect it.
+    import sys
+    import yaml
+    import matplotlib
+    import matplotlib.pyplot as plt
+    from torch.utils.data import DataLoader
+
+    plt.switch_backend("TkAgg")
+    sys.path.append("../")
+
+    # Load configuration file
+    with open("./config/yorkurban_dataset_config.yaml", "r") as config_file:
+        config = yaml.safe_load(config_file)
+
+    # Augmentation is not implemented for this dataset
+    config["add_augmentation_to_all_splits"] = False
+
+    # Initialize the dataset
+    test_dataset = YorkUrbanDataset(mode="test", config=config)
+
+    # Interactive inspection of test_dataset
+    import ipdb
+    ipdb.set_trace()
\ No newline at end of file
diff --git a/scalelsd/ssl/misc/__init__.py b/scalelsd/ssl/misc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scalelsd/ssl/misc/geometry_utils.py b/scalelsd/ssl/misc/geometry_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d453be93a9f4e443da8f69efd3753b94354d381
--- /dev/null
+++ b/scalelsd/ssl/misc/geometry_utils.py
@@ -0,0 +1,282 @@
+import numpy as np
+import torch
+
+from shapely.geometry.polygon import LinearRing
+
+try:
+ from pycolmap import image_to_world, world_to_image
+except:
+ pass
+
+### Point-related utils
+
+# Warp a list of points using a homography
+def warp_points(points, homography):
+    """Apply a homography to points given in (row, col) order.
+
+    The last axis of ``points`` is swapped to (x, y), extended with a
+    homogeneous 1, multiplied by ``homography``, divided by the third
+    coordinate and swapped back to (row, col).
+    """
+    xy = points[..., [1, 0]]
+    ones = np.ones_like(points[..., :1])
+    homogeneous = np.concatenate([xy, ones], axis=-1)
+    # Row-vector form of ``homography @ p`` for every point
+    warped = (homography @ homogeneous.T).T
+    # De-homogenise and go back to (row, col)
+    return warped[..., [1, 0]] / warped[..., 2:]
+
+
+# Mask out the points that are outside of img_size
+def mask_points(points, img_size):
+    """Return a boolean mask of the points lying inside ``img_size``.
+
+    A point is kept when its first coordinate is in [0, img_size[0]) and
+    its second coordinate is in [0, img_size[1]).
+    """
+    first, second = points[..., 0], points[..., 1]
+    first_inside = (first >= 0) & (first < img_size[0])
+    second_inside = (second >= 0) & (second < img_size[1])
+    return first_inside & second_inside
+
+
+# Convert a tensor [N, 2] or batched tensor [B, N, 2] of N keypoints into
+# a grid in [-1, 1]² that can be used in torch.nn.functional.grid_sample
+def keypoints_to_grid(keypoints, img_size):
+    """Rescale keypoints to [-1, 1] and return them as a sampling grid.
+
+    Each coordinate is mapped with ``k * 2 / img_size - 1``, the last axis
+    is swapped, and the result is viewed as (-1, n_points, 1, 2).
+    """
+    num_kp = keypoints.size()[-2]
+    size = torch.tensor(img_size, dtype=torch.float,
+                        device=keypoints.device)
+    normalized = keypoints.float() * 2. / size - 1.
+    return normalized[..., [1, 0]].view(-1, num_kp, 1, 2)
+
+
+# Return a 2D matrix indicating the local neighborhood of each point
+# for a given threshold and two lists of corresponding keypoints
+def get_dist_mask(kp0, kp1, valid_mask, dist_thresh):
+    """Return a 2D mask of keypoint pairs that are close in either set.
+
+    For every batch element the pairwise distances inside ``kp0`` and
+    inside ``kp1`` are computed; a pair is marked when the smaller of the
+    two distances is at most ``dist_thresh``. The per-batch masks are
+    tiled to a (b_size * n_points) square matrix and then restricted to
+    the rows / columns selected by ``valid_mask``.
+    """
+    b_size, n_points, _ = kp0.size()
+    pairwise0 = torch.norm(kp0.unsqueeze(2) - kp0.unsqueeze(1), dim=-1)
+    pairwise1 = torch.norm(kp1.unsqueeze(2) - kp1.unsqueeze(1), dim=-1)
+    close = torch.min(pairwise0, pairwise1) <= dist_thresh
+    side = b_size * n_points
+    close = close.repeat(1, 1, b_size).reshape(side, side)
+    return close[valid_mask, :][:, valid_mask]
+
+
+### Line-related utils
+
+# Sample n points along lines of shape (num_lines, 2, 2)
+def sample_line_points(lines, n):
+    """Sample ``n`` evenly spaced points (endpoints included) per line.
+
+    ``lines`` has shape (num_lines, 2, 2); the result has shape
+    (num_lines, n, 2).
+    """
+    starts, ends = lines[:, 0], lines[:, 1]
+    first_coord = np.linspace(starts[:, 0], ends[:, 0], n, axis=-1)
+    second_coord = np.linspace(starts[:, 1], ends[:, 1], n, axis=-1)
+    return np.stack([first_coord, second_coord], axis=2)
+
+
+# Return a mask of the valid lines that are within a valid mask of an image
+def mask_lines(lines, valid_mask):
+    """Return, per line, whether both endpoints fall on valid pixels.
+
+    Endpoints are rounded to integers and clipped to the extent of
+    ``valid_mask`` before the lookup.
+    """
+    height, width = valid_mask.shape
+    endpoints = np.clip(np.round(lines).astype(int), 0,
+                        [height - 1, width - 1])
+    first_ok = valid_mask[endpoints[:, 0, 0], endpoints[:, 0, 1]]
+    second_ok = valid_mask[endpoints[:, 1, 0], endpoints[:, 1, 1]]
+    return first_ok & second_ok
+
+
+# Return a 2D matrix indicating for each pair of points
+# if they are on the same line or not
+def get_common_line_mask(line_indices, valid_mask):
+    """Return a 2D mask telling which point pairs share a line index.
+
+    The per-batch comparison of ``line_indices`` is tiled to a
+    (b_size * n_points) square matrix and then restricted to the rows /
+    columns selected by ``valid_mask``.
+    """
+    b_size, n_points = line_indices.shape
+    same_line = line_indices[:, :, None] == line_indices[:, None, :]
+    side = b_size * n_points
+    same_line = same_line.repeat(1, 1, b_size).reshape(side, side)
+    return same_line[valid_mask, :][:, valid_mask]
+
+
+# Compute the distances between two sets of lines using the sAP distance
+def get_sAP_line_distance(warped_ref_line_seg, target_line_seg):
+    """Return the pairwise endpoint distance between two sets of segments.
+
+    For every (reference, target) pair the two possible endpoint
+    assignments are evaluated (same order and swapped order) and the
+    smaller sum of endpoint distances is returned.
+    """
+    # diff[n, m, i, j] = ref endpoint i of line n minus target endpoint j of line m
+    diff = warped_ref_line_seg[:, None, :, None] - target_line_seg[:, None]
+    endpoint_dist = ((diff ** 2).sum(-1)) ** 0.5
+    same_order = endpoint_dist[:, :, 0, 0] + endpoint_dist[:, :, 1, 1]
+    swapped = endpoint_dist[:, :, 0, 1] + endpoint_dist[:, :, 1, 0]
+    return np.minimum(same_order, swapped)
+
+
+# Given a list of line segments and a list of points (2D or 3D coordinates),
+# compute the orthogonal projection of all points on all lines.
+# This returns the 1D coordinates of the projection on the line,
+# as well as the list of orthogonal distances.
+def project_point_to_line(line_segs, points):
+    """Orthogonally project every point on every line.
+
+    Returns:
+        coords1d: (n_lines, n_points) coordinate of each projection along
+            its line, where 0 is the first endpoint and 1 the second.
+        dist_to_line: (n_lines, n_points) distance between each point and
+            its projection.
+    """
+    origin = line_segs[:, None, 0]
+    direction = (line_segs[:, 1] - line_segs[:, 0])[:, None]
+
+    # Scalar projection normalised by the squared segment length
+    dot = ((points[None] - origin) * direction).sum(axis=2)
+    coords1d = dot / np.linalg.norm(direction, axis=2) ** 2
+
+    # Foot of the perpendicular, then its distance to the point
+    foot = origin + coords1d[:, :, None] * direction
+    dist_to_line = np.linalg.norm(foot - points[None], axis=2)
+
+    return coords1d, dist_to_line
+
+
+# Given a list of segments parameterized by the 1D coordinate of the endpoints
+# compute the overlap with the segment [0, 1]
+def get_segment_overlap(seg_coord1d):
+    """Return the length of the overlap of each 1D segment with [0, 1].
+
+    The last axis of ``seg_coord1d`` holds the two endpoint coordinates,
+    in any order. Segments that do not reach into (0, 1) get 0.
+    """
+    ordered = np.sort(seg_coord1d, axis=-1)
+    low, high = ordered[..., 0], ordered[..., 1]
+    intersects = (high > 0) * (low < 1)
+    clipped_length = np.minimum(high, 1) - np.maximum(low, 0)
+    return intersects * clipped_length
+
+
+# Compute the symmetrical orthogonal line distance between two sets of lines
+# and the average overlapping ratio of both lines.
+# Enforce a high line distance for small overlaps.
+# This is compatible for nD objects (e.g. both lines in 2D or 3D).
+def get_overlap_orth_line_dist(line_seg1, line_seg2, min_overlap=0.5):
+    """Return the symmetric orthogonal distance between two line sets.
+
+    The endpoints of each set are projected on the lines of the other
+    set; the distance of a pair is the mean over both directions of the
+    summed endpoint distances. Pairs whose mean overlap ratio is below
+    ``min_overlap`` are assigned the largest distance in the matrix.
+
+    Returns:
+        Array of shape (len(line_seg1), len(line_seg2)).
+    """
+    n1, n2 = len(line_seg1), len(line_seg2)
+    endpoints1 = line_seg1.reshape(n1 * 2, -1)
+    endpoints2 = line_seg2.reshape(n2 * 2, -1)
+
+    # Project each set of endpoints on the lines of the other set
+    coords_2_on_1, dists_2_to_1 = project_point_to_line(line_seg1, endpoints2)
+    coords_1_on_2, dists_1_to_2 = project_point_to_line(line_seg2, endpoints1)
+
+    # Average orthogonal line distance
+    dists_2_to_1 = dists_2_to_1.reshape(n1, n2, 2).sum(axis=2)
+    dists_1_to_2 = dists_1_to_2.reshape(n2, n1, 2).sum(axis=2)
+    line_dists = (dists_2_to_1 + dists_1_to_2.T) / 2
+
+    # Average overlapping ratio
+    overlap_on_1 = get_segment_overlap(coords_2_on_1.reshape(n1, n2, 2))
+    overlap_on_2 = get_segment_overlap(coords_1_on_2.reshape(n2, n1, 2)).T
+    mean_overlap = (overlap_on_1 + overlap_on_2) / 2
+
+    # Enforce a max line distance for line segments with small overlap
+    line_dists[mean_overlap < min_overlap] = np.amax(line_dists)
+    return line_dists
+
+
+### 3D geometry utils
+
+# Convert from quaternions to rotation matrix
+def qvec2rotmat(qvec):
+    """Convert a quaternion in (w, x, y, z) order to a 3x3 rotation matrix."""
+    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]
+    row0 = [1 - 2 * y**2 - 2 * z**2,
+            2 * x * y - 2 * w * z,
+            2 * z * x + 2 * w * y]
+    row1 = [2 * x * y + 2 * w * z,
+            1 - 2 * x**2 - 2 * z**2,
+            2 * y * z - 2 * w * x]
+    row2 = [2 * z * x - 2 * w * y,
+            2 * y * z + 2 * w * x,
+            1 - 2 * x**2 - 2 * y**2]
+    return np.array([row0, row1, row2])
+
+
+# Convert a rotation matrix to quaternions
+def rotmat2qvec(R):
+    """Convert a 3x3 rotation matrix to a quaternion.
+
+    The quaternion is read from the eigenvector with the largest
+    eigenvalue of a symmetric 4x4 matrix built from ``R`` (only its
+    lower triangle is filled, which is what ``np.linalg.eigh`` reads by
+    default). The sign is chosen so the first component is non-negative.
+    """
+    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+    sym = np.array([
+        [Rxx - Ryy - Rzz, 0, 0, 0],
+        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
+    values, vectors = np.linalg.eigh(sym)
+    dominant = vectors[:, np.argmax(values)]
+    # Reorder so the last eigenvector component comes first
+    qvec = dominant[[3, 0, 1, 2]]
+    if qvec[0] < 0:
+        qvec *= -1
+    return qvec
+
+
+# Read the camera intrinsics from a file in COLMAP format
+def read_cameras(camera_file, scale_factor=None):
+    """Read camera intrinsics from a text file in COLMAP format.
+
+    The first three lines of the file are skipped. Every remaining line
+    is split on single spaces: field 1 is the model, fields 2 and 3 the
+    width and height, and the rest the float parameters.
+
+    Args:
+        camera_file: path of the cameras text file.
+        scale_factor: if not None, every camera is passed through
+            ``scale_intrinsics`` with this factor.
+
+    Returns:
+        A list of dicts with keys "model", "width", "height", "params".
+    """
+    with open(camera_file, 'r') as f:
+        rows = f.read().rstrip().split('\n')[3:]
+
+    cameras = [
+        {"model": fields[1],
+         "width": int(fields[2]),
+         "height": int(fields[3]),
+         "params": np.array([float(v) for v in fields[4:]])}
+        for fields in (row.split(' ') for row in rows)]
+
+    # Optionally scale the intrinsics if the image are resized
+    if scale_factor is None:
+        return cameras
+    return [scale_intrinsics(cam, scale_factor) for cam in cameras]
+
+
+# Adapt the camera intrinsics to an image resize
+def scale_intrinsics(intrinsics, scale_factor):
+    """Return camera intrinsics adapted to an image resize.
+
+    Args:
+        intrinsics: dict with keys "model", "width", "height" and
+            "params"; the first two params are scaled as focal lengths
+            and the next two as the principal point.
+        scale_factor: resize factor applied to the image.
+
+    Returns:
+        A new dict with the same keys. The input dict and its "params"
+        array are left untouched.
+    """
+    # Work on a float copy: the in-place updates below must not leak into
+    # the caller's array, and would fail on an integer array.
+    params = np.array(intrinsics["params"], dtype=float)
+    # Adapt the focal length
+    params[:2] *= scale_factor
+    # Adapt the principal point (pixel centers sit at +0.5)
+    params[2:4] = (params[2:4] * scale_factor + 0.5) - 0.5
+    return {
+        "model": intrinsics["model"],
+        # int(x + 0.5) rounds the new size to the nearest integer
+        "width": int(intrinsics["width"] * scale_factor + 0.5),
+        "height": int(intrinsics["height"] * scale_factor + 0.5),
+        "params": params,
+    }
+
+
+# Project points from 2D to 3D, in (x, y, z) format
+def project_2d_to_3d(points, depth, T_local_to_world, intrinsics):
+ """Lift 2D points with known depth to 3D points in (x, y, z) format.
+
+ The first two columns of ``points`` are swapped before the call to
+ ``image_to_world`` and ``depth`` is indexed as a 1D array with one value
+ per point. ``image_to_world`` comes from the optional pycolmap import at
+ the top of the file; if that import failed, calling this raises NameError.
+ NOTE(review): presumably ``points`` is (N, 2) in (row, col) order and
+ ``T_local_to_world`` a 4x4 matrix -- confirm against the callers.
+ """
+ # Warp to world homogeneous coordinates
+ world_points = image_to_world(points[:, [1, 0]],
+ intrinsics)['world_points']
+ # Scale the returned coordinates by the per-point depth (in place)
+ world_points *= depth[:, None]
+ # Append the depth and a homogeneous 1 as extra columns
+ world_points = np.concatenate([world_points, depth[:, None],
+ np.ones((len(depth), 1))], axis=1)
+
+ # Warp to the world coordinates
+ world_points = (T_local_to_world @ world_points.T).T
+ # Divide by the last (homogeneous) coordinate
+ world_points = world_points[:, :3] / world_points[:, 3:]
+ return world_points
+
+
+# Project points from 3D in (x, y, z) format to 2D
+def project_3d_to_2d(points, T_world_to_local, intrinsics):
+ """Project 3D points in (x, y, z) format to 2D image points.
+
+ ``world_to_image`` comes from the optional pycolmap import at the top of
+ the file; if that import failed, calling this raises NameError. The two
+ columns of its output are swapped before returning.
+ NOTE(review): presumably ``T_world_to_local`` is a 4x4 matrix and the
+ result is in (row, col) order -- confirm against the callers.
+ """
+ # Homogeneous coordinates, then transform with T_world_to_local
+ norm_points = np.concatenate([points, np.ones((len(points), 1))], axis=1)
+ norm_points = (T_world_to_local @ norm_points.T).T
+ # Divide by the homogeneous coordinate, then by the third coordinate
+ norm_points = norm_points[:, :3] / norm_points[:, 3:]
+ norm_points = norm_points[:, :2] / norm_points[:, 2:]
+ image_points = world_to_image(norm_points, intrinsics)
+ # Stack the per-point results and swap the two columns
+ image_points = np.stack(image_points['image_points'])[:, [1, 0]]
+ return image_points
+
+
+### Line-ellipse intersection
+
+# Sample n points along each ellipse, given as a list of
+# tuples (x0, y0, a, b, theta) with theta in degrees. Each
+# ellipse is approximated by the returned polygon.
+def ellipse_polyline(ellipses, n=100):
+    """Approximate each ellipse by a polygon with ``n`` vertices.
+
+    Args:
+        ellipses: iterable of (x0, y0, a, b, angle) tuples, where the
+            angle is given in degrees.
+        n: number of vertices sampled per ellipse.
+
+    Returns:
+        A list with one (n, 2) array of vertices per ellipse.
+    """
+    t = np.linspace(0, 2*np.pi, n, endpoint=False)
+    sin_t, cos_t = np.sin(t), np.cos(t)
+    polygons = []
+    for x0, y0, a, b, angle in ellipses:
+        theta = np.deg2rad(angle)
+        sin_a, cos_a = np.sin(theta), np.cos(theta)
+        # Axis-aligned ellipse rotated by theta and shifted to (x0, y0)
+        xs = x0 + a * cos_a * cos_t - b * sin_a * sin_t
+        ys = y0 + a * sin_a * cos_t + b * cos_a * sin_t
+        polygons.append(np.column_stack([xs, ys]))
+    return polygons
+
+
+# Compute the intersections between an ellipse a and a line.
+def intersect_line_ellipse(a, line):
+    """Return the intersection points of an ellipse polygon and a line.
+
+    Args:
+        a: sequence of vertices approximating the ellipse; it is turned
+            into a shapely ``LinearRing``.
+        line: shapely geometry intersected with the ring.
+
+    Returns:
+        Array of shape (k, 2) with one (x, y) row per intersection point;
+        k is 0 when there is no intersection.
+
+    Raises:
+        ValueError: if the intersection is neither empty, a Point nor a
+            MultiPoint.
+    """
+    ring = LinearRing(a)
+    mp = ring.intersection(line)
+    if mp.is_empty:
+        return np.empty((0, 2))
+    elif mp.geom_type == 'Point':
+        return np.array([[mp.x, mp.y]])
+    elif mp.geom_type == 'MultiPoint':
+        # Multi-part geometries are no longer iterable in Shapely 2.x;
+        # `.geoms` works on both 1.x and 2.x.
+        return np.array([[p.x, p.y] for p in mp.geoms])
+    else:
+        raise ValueError('Impossible geometry: ' + mp.geom_type)
\ No newline at end of file
diff --git a/scalelsd/ssl/misc/train_utils.py b/scalelsd/ssl/misc/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d218c1521be2ae1abfc6e945e0f8f165c72f9241
--- /dev/null
+++ b/scalelsd/ssl/misc/train_utils.py
@@ -0,0 +1,56 @@
+"""
+This file contains some useful functions for train / val.
+"""
+import os
+import numpy as np
+import torch
+import random
+from scalelsd.ssl.models.detector import ScaleLSD
+
+
+################
+## HDF5 utils ##
+################
+def parse_h5_data(h5_data):
+    """Copy every entry of an h5 dataset into a dict of numpy arrays."""
+    return {key: np.array(h5_data[key]) for key in h5_data.keys()}
+
+
+def fix_seeds(random_seed):
+ """Seed the Python, NumPy and PyTorch random generators.
+
+ Also exports PYTHONHASHSEED and sets the cuDNN flags to
+ deterministic / non-benchmark mode.
+ """
+ random.seed(random_seed)
+ np.random.seed(random_seed)
+ # NOTE(review): setting PYTHONHASHSEED here presumably only affects
+ # child processes, not the running interpreter -- confirm if relied on.
+ os.environ['PYTHONHASHSEED'] = str(random_seed)
+ torch.manual_seed(random_seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(random_seed)
+ torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
+
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ # Stricter / optional determinism switches, currently disabled:
+ # torch.backends.cudnn.allow_tf32 = args.tf32
+ # torch.backends.cuda.matmul.allow_tf32 = args.tf32
+ # torch.backends.cudnn.deterministic = args.dtm
+
+ # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+ # torch.use_deterministic_algorithms(True)
+
+
+def load_scalelsd_model(ckpt_path, device='cuda'):
+    """Build a ScaleLSD model in eval mode and load checkpoint weights.
+
+    Args:
+        ckpt_path: path of the checkpoint file. Layer scale is disabled
+            when the path contains the substring 'v1', enabled otherwise.
+        device: device the model is moved to.
+
+    Returns:
+        The ScaleLSD model with the checkpoint weights loaded.
+    """
+    use_layer_scale = 'v1' not in ckpt_path
+
+    model = ScaleLSD(gray_scale=True, use_layer_scale=use_layer_scale)
+    model = model.eval().to(device)
+    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
+    state_dict = torch.load(ckpt_path, map_location='cpu')
+    # A checkpoint either wraps the weights under 'model_state' or is the
+    # bare state dict. Pick the right one explicitly instead of using a
+    # bare `except`, which replaced a genuine loading error (e.g. a key
+    # mismatch) by the unrelated failure of the fallback attempt.
+    if 'model_state' in state_dict:
+        state_dict = state_dict['model_state']
+    model.load_state_dict(state_dict)
+
+    return model
+
diff --git a/scalelsd/ssl/misc/visualize_util.py b/scalelsd/ssl/misc/visualize_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a72ce08c4685d7c6188e0fd5eb05937e550e6e2f
--- /dev/null
+++ b/scalelsd/ssl/misc/visualize_util.py
@@ -0,0 +1,526 @@
+""" Organize some frequently used visualization functions. """
+import cv2
+import numpy as np
+import matplotlib
+import matplotlib.pyplot as plt
+import copy
+import seaborn as sns
+
+
+# Plot junctions onto the image (return a separate copy)
+def plot_junctions(input_image, junctions, junc_size=3, color=None):
+    """Draw junctions as circles on a copy of the image.
+
+    Args:
+        input_image: can be 0~1 float or 0~255 uint8; single-channel
+            images are expanded to 3 channels.
+        junctions: Nx2 or 2xN np array in (row, col) order.
+        junc_size: the radius of the plotted circles.
+        color: color of the circles, green (0, 255, 0) if None.
+
+    Returns:
+        The uint8 image with the junctions drawn on it.
+
+    Raises:
+        ValueError: if the image type / range or the junction shape is
+            not supported.
+    """
+    # Create image copy
+    image = copy.copy(input_image)
+    # `np.float` was removed in NumPy 1.24 (it was an alias of the builtin
+    # float, i.e. float64), so float32 / float64 covers the same dtypes.
+    is_float = image.dtype in (np.float32, np.float64)
+    # Make sure the image is converted to 255 uint8
+    if image.dtype == np.uint8:
+        pass
+    # A float type image ranging from 0~1
+    elif is_float and image.max() <= 2.:
+        image = (image * 255.).astype(np.uint8)
+    # A float type image ranging from 0.~255.
+    elif is_float and image.mean() > 10.:
+        image = image.astype(np.uint8)
+    else:
+        raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")
+
+    # Check whether the image is single channel
+    if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
+        # Squeeze to H*W first
+        image = image.squeeze()
+
+        # Stack to channel 3
+        image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
+
+    # Junction dimensions should be N*2
+    if not len(junctions.shape) == 2:
+        raise ValueError("[Error] junctions should be 2-dim array.")
+
+    # Always convert to N*2
+    if junctions.shape[-1] != 2:
+        if junctions.shape[0] == 2:
+            junctions = junctions.T
+        else:
+            raise ValueError("[Error] At least one of the two dims should be 2.")
+
+    # Round and convert junctions to int (and check the boundary)
+    H, W = image.shape[:2]
+    junctions = (np.round(junctions)).astype(np.int32)
+    junctions[junctions < 0] = 0
+    junctions[junctions[:, 0] >= H, 0] = H-1  # (first dim) max bounded by H-1
+    junctions[junctions[:, 1] >= W, 1] = W-1  # (second dim) max bounded by W-1
+
+    # Iterate through all the junctions
+    if color is None:
+        color = (0, 255., 0)
+    for junc in junctions:
+        # cv2 expects the center as (x, y) plain ints
+        center = (int(junc[1]), int(junc[0]))
+        cv2.circle(image, center, radius=junc_size,
+                   color=color, thickness=3)
+
+    return image
+
+
+# Plot line segments given junctions and a line adjacency map
+def plot_line_segments(input_image, junctions, line_map, junc_size=3,
+                       color=(0, 255., 0), line_width=1, plot_survived_junc=True):
+    """Draw the segments encoded by a junction adjacency map on an image copy.
+
+    Args:
+        input_image: can be 0~1 float or 0~255 uint8; single-channel
+            images are expanded to 3 channels.
+        junctions: Nx2 or 2xN np array in (row, col) order.
+        line_map: NxN np array; an entry equal to 1 connects two junctions.
+        junc_size: the radius of the plotted circles.
+        color: color of the line segments, a list / tuple of length 3 or
+            the string "random" (one random color per segment).
+        line_width: width of the drawn segments.
+        plot_survived_junc: if True only the junctions that belong to a
+            segment are drawn, otherwise all junctions are drawn.
+
+    Returns:
+        The uint8 image with segments and junctions drawn on it.
+
+    Raises:
+        ValueError: on unsupported image type / range, junction shape,
+            line_map shape or color.
+    """
+    # Create image copy
+    image = copy.copy(input_image)
+    # `np.float` was removed in NumPy 1.24 (it was an alias of the builtin
+    # float, i.e. float64), so float32 / float64 covers the same dtypes.
+    is_float = image.dtype in (np.float32, np.float64)
+    # Make sure the image is converted to 255 uint8
+    if image.dtype == np.uint8:
+        pass
+    # A float type image ranging from 0~1
+    elif is_float and image.max() <= 2.:
+        image = (image * 255.).astype(np.uint8)
+    # A float type image ranging from 0.~255.
+    elif is_float and image.mean() > 10.:
+        image = image.astype(np.uint8)
+    else:
+        raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")
+
+    # Check whether the image is single channel
+    if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
+        # Squeeze to H*W first
+        image = image.squeeze()
+
+        # Stack to channel 3
+        image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
+
+    # Junction dimensions should be 2
+    if not len(junctions.shape) == 2:
+        raise ValueError("[Error] junctions should be 2-dim array.")
+
+    # Always convert to N*2
+    if junctions.shape[-1] != 2:
+        if junctions.shape[0] == 2:
+            junctions = junctions.T
+        else:
+            raise ValueError("[Error] At least one of the two dims should be 2.")
+
+    # line_map dimension should be 2
+    if not len(line_map.shape) == 2:
+        raise ValueError("[Error] line_map should be 2-dim array.")
+
+    # Color should be "random" or a list or tuple with length 3
+    random_color = isinstance(color, str) and color == "random"
+    if not random_color:
+        if not isinstance(color, (tuple, list)):
+            raise ValueError("[Error] color should have type list or tuple.")
+        if len(color) != 3:
+            raise ValueError("[Error] color should be a list or tuple with length 3.")
+
+    # Make a copy of the line_map (entries are cleared while parsing)
+    line_map_tmp = copy.copy(line_map)
+
+    # Parse line_map back to segment pairs, stored as (x1, y1, x2, y2)
+    segment_list = []
+    for idx in range(junctions.shape[0]):
+        # if no connectivity, just skip it
+        if line_map_tmp[idx, :].sum() == 0:
+            continue
+        for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
+            p1 = np.flip(junctions[idx, :])   # Convert to xy format
+            p2 = np.flip(junctions[idx2, :])  # Convert to xy format
+            segment_list.append([p1[0], p1[1], p2[0], p2[1]])
+
+            # Clear both directions so the segment is recorded only once
+            line_map_tmp[idx, idx2] = 0
+            line_map_tmp[idx2, idx] = 0
+    segments = np.array(segment_list, dtype=np.float64).reshape(-1, 4)
+    segments = np.round(segments).astype(np.int32)
+
+    # Draw segment pairs
+    for seg in segments:
+        # Decide the color: the image is uint8, so random colors are drawn
+        # in 0~255 and re-sampled for every segment.
+        if random_color:
+            seg_color = tuple((np.random.rand(3) * 255.).tolist())
+        else:
+            seg_color = tuple(color)
+        cv2.line(image, (int(seg[0]), int(seg[1])), (int(seg[2]), int(seg[3])),
+                 color=seg_color, thickness=line_width)
+
+    # Also draw the junctions
+    if not plot_survived_junc:
+        # cv2 needs integer pixel coordinates
+        int_junctions = np.round(junctions).astype(np.int32)
+        for junc in int_junctions:
+            cv2.circle(image, (int(junc[1]), int(junc[0])), radius=junc_size,
+                       color=(0, 255., 0), thickness=3)
+    # Only plot the junctions which are part of a line segment
+    else:
+        for seg in segments:  # Segments are already in xy format.
+            cv2.circle(image, (int(seg[0]), int(seg[1])), radius=junc_size,
+                       color=(0, 255., 0), thickness=3)
+            cv2.circle(image, (int(seg[2]), int(seg[3])), radius=junc_size,
+                       color=(0, 255., 0), thickness=3)
+
+    return image
+
+
+# Plot line segments given Nx4 or Nx2x2 line segments
+def plot_line_segments_from_segments(input_image, line_segments, junc_size=3,
+                                     color=(0, 255., 0), line_width=1):
+    """Draw line segments and their endpoints on a copy of the image.
+
+    Args:
+        input_image: can be 0~1 float or 0~255 uint8; single-channel
+            images are expanded to 3 channels.
+        line_segments: Nx4 or Nx2x2 np array in (row, col) order.
+        junc_size: the radius of the circles drawn at the endpoints.
+        color: color of the line segments, a list / tuple of length 3 or
+            the string "random" (one random color per segment).
+        line_width: width of the drawn segments.
+
+    Returns:
+        The uint8 image with segments and endpoints drawn on it.
+
+    Raises:
+        ValueError: on unsupported image type / range or segment shape.
+    """
+    # Create image copy
+    image = copy.copy(input_image)
+    # `np.float` was removed in NumPy 1.24 (it was an alias of the builtin
+    # float, i.e. float64), so float32 / float64 covers the same dtypes.
+    is_float = image.dtype in (np.float32, np.float64)
+    # Make sure the image is converted to 255 uint8
+    if image.dtype == np.uint8:
+        pass
+    # A float type image ranging from 0~1
+    elif is_float and image.max() <= 2.:
+        image = (image * 255.).astype(np.uint8)
+    # A float type image ranging from 0.~255.
+    elif is_float and image.mean() > 10.:
+        image = image.astype(np.uint8)
+    else:
+        raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")
+
+    # Check whether the image is single channel
+    if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
+        # Squeeze to H*W first
+        image = image.squeeze()
+
+        # Stack to channel 3
+        image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
+
+    # Check if line_segments are in (1) Nx4, or (2) Nx2x2.
+    H, W, _ = image.shape
+    is_flat = len(line_segments.shape) == 2 and line_segments.shape[-1] == 4
+    is_paired = len(line_segments.shape) == 3 and line_segments.shape[-1] == 2
+    if not (is_flat or is_paired):
+        raise ValueError("[Error] line_segments should be either Nx4 or Nx2x2 in HW format.")
+
+    # Cast to int32 (truncates towards zero)
+    line_segments = line_segments.astype(np.int32)
+    if is_flat:
+        # Nx4 rows are (h1, w1, h2, w2); regroup them as Nx2x2
+        line_segments = line_segments.reshape(-1, 2, 2)
+    # Clip H and W dimensions to the image extent
+    line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H-1)
+    line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W-1)
+
+    random_color = isinstance(color, str) and color == "random"
+
+    # Draw segment pairs (all segments should be in HW format)
+    image = image.copy()
+    for seg in line_segments:
+        # Decide the color: the image is uint8, so random colors are drawn
+        # in 0~255 and re-sampled for every segment.
+        if random_color:
+            seg_color = tuple((np.random.rand(3) * 255.).tolist())
+        else:
+            seg_color = tuple(color)
+        # cv2 expects (x, y) plain ints
+        start = (int(seg[0, 1]), int(seg[0, 0]))
+        end = (int(seg[1, 1]), int(seg[1, 0]))
+        cv2.line(image, start, end, color=seg_color, thickness=line_width)
+
+        # Also draw the junctions
+        cv2.circle(image, start, radius=junc_size, color=(0, 255., 0), thickness=3)
+        cv2.circle(image, end, radius=junc_size, color=(0, 255., 0), thickness=3)
+
+    return image
+
+
+# Additional functions to visualize multiple images at the same time,
+# e.g. for line matching
+def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5):
+    """Plot a set of images horizontally in a new figure.
+
+    Args:
+        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
+        titles: a list of strings, as titles for each image.
+        cmaps: colormaps for monochrome images (one name, or one per image).
+        dpi: resolution of the figure.
+        size: width of one image in the figure; None uses the default size.
+        pad: padding passed to ``tight_layout``.
+    """
+    n = len(imgs)
+    if not isinstance(cmaps, (list, tuple)):
+        cmaps = [cmaps] * n
+    figsize = None if size is None else (size*n, size*3/4)
+    fig, axes = plt.subplots(1, n, figsize=figsize, dpi=dpi)
+    # A single subplot is returned bare, not as a sequence
+    if n == 1:
+        axes = [axes]
+    for i in range(n):
+        axis = axes[i]
+        axis.imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+        axis.get_yaxis().set_ticks([])
+        axis.get_xaxis().set_ticks([])
+        axis.set_axis_off()
+        for spine in axis.spines.values():  # remove frame
+            spine.set_visible(False)
+        if titles:
+            axis.set_title(titles[i])
+    fig.tight_layout(pad=pad)
+
+
+def plot_keypoints(kpts, colors='lime', ps=4):
+    """Scatter keypoints on the axes of the current figure.
+
+    Args:
+        kpts: list of ndarrays of size (N, 2), one per axes.
+        colors: string, or list of list of tuples (one for each keypoints).
+        ps: size of the keypoints as float.
+    """
+    if not isinstance(colors, list):
+        colors = [colors] * len(kpts)
+    for axis, points, point_color in zip(plt.gcf().axes, kpts, colors):
+        axis.scatter(points[:, 0], points[:, 1],
+                     c=point_color, s=ps, linewidths=0)
+
+
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
+ """Plot matches for a pair of existing images.
+ The lines are drawn on the current figure, across two of its axes.
+ Args:
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines (no line if lw=0).
+ ps: size of the end points (no endpoint if ps=0)
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ # Draw once before the data -> figure transforms are used below
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ # One random hue per match
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ # A single color was given: repeat it for every match
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
+ fig.lines += [matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
+ alpha=a)
+ for i in range(len(kpts0))]
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+ if ps > 0:
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2)
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2)
+
+
+def plot_lines(lines, line_colors='orange', point_colors='cyan',
+ ps=4, lw=2, indices=(0, 1)):
+ """Plot lines and endpoints for existing images.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2), one per image.
+ line_colors: color of the lines, a single value or a list with one
+ entry per image.
+ point_colors: color of the endpoints, a single value or a list with
+ one entry per image.
+ ps: size of the keypoints as float pixels.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ # Broadcast single colors to one entry per image
+ if not isinstance(line_colors, list):
+ line_colors = [line_colors] * len(lines)
+ if not isinstance(point_colors, list):
+ point_colors = [point_colors] * len(lines)
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines and junctions
+ for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
+ for i in range(len(l)):
+ line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]),
+ (l[i, 0, 1], l[i, 1, 1]),
+ zorder=1, c=lc, linewidth=lw)
+ a.add_line(line)
+ # All endpoints of this image as a flat list of 2D points
+ pts = l.reshape(-1, 2)
+ a.scatter(pts[:, 0], pts[:, 1],
+ c=pc, s=ps, linewidths=0, zorder=2)
+
+
+def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.):
+ """Plot matches for a pair of existing images, parametrized by their middle point.
+ The lines are drawn on the current figure, across two of its axes.
+ Args:
+ kpts0, kpts1: corresponding middle points of the lines of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines (nothing is drawn if lw=0).
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ # Draw once before the data -> figure transforms are used below
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ # One random hue per match
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ # A single color was given: repeat it for every match
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
+ fig.lines += [matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
+ alpha=a)
+ for i in range(len(kpts0))]
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+
+def plot_color_line_matches(lines, correct_matches=None,
+ lw=2, indices=(0, 1)):
+ """Plot line matches for existing images with multiple colors.
+ Line i gets the same color in every image; the number of colors is
+ taken from the first entry of ``lines``.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ correct_matches: bool array of size (N,) indicating correct matches.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ n_lines = len(lines[0])
+ # One shuffled 'husl' color per line
+ colors = sns.color_palette('husl', n_colors=n_lines)
+ np.random.shuffle(colors)
+ alphas = np.ones(n_lines)
+ # If correct_matches is not None, display wrong matches with a low alpha
+ if correct_matches is not None:
+ alphas[~np.array(correct_matches)] = 0.2
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l in zip(axes, lines):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=colors[i],
+ alpha=alphas[i], linewidth=lw) for i in range(n_lines)]
+
+
+def plot_color_lines(lines, correct_matches, wrong_matches,
+ lw=2, indices=(0, 1)):
+ """Plot line matches for existing images with multiple colors:
+ green for correct matches, red for wrong ones, and blue for the rest.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ correct_matches: list of bool arrays of size (N,) with correct matches.
+ wrong_matches: list of bool arrays of size (N,) with wrong matches.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ # palette = sns.color_palette()
+ palette = sns.color_palette("hls", 8)
+ blue = palette[5] # palette[0]
+ red = palette[0] # palette[3]
+ green = palette[2] # palette[2]
+ # Start with every line in blue, then recolor the matched ones.
+ # Red is assigned last, so it wins if a line is flagged in both masks.
+ colors = [np.array([blue] * len(l)) for l in lines]
+ for i, c in enumerate(colors):
+ c[np.array(correct_matches[i])] = green
+ c[np.array(wrong_matches[i])] = red
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l, c in zip(axes, lines, colors):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=c[i],
+ linewidth=lw) for i in range(len(l))]
+
+
+def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
+ """ Plot line matches for existing images with multiple colors and
+ highlight the actually matched subsegments.
+ Full lines are drawn in semi-transparent red; the subsegments are drawn
+ afterwards, each with its own palette color.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ subsegments: list of ndarrays of size (N, 2, 2).
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ # The number of lines is taken from the first image
+ n_lines = len(lines[0])
+ colors = sns.cubehelix_palette(start=2, rot=-0.2, dark=0.3, light=.7,
+ gamma=1.3, hue=1, n_colors=n_lines)
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l, ss in zip(axes, lines, subsegments):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+
+ # Draw full line
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1, transform=fig.transFigure, c='red',
+ alpha=0.7, linewidth=lw) for i in range(n_lines)]
+
+ # Draw matched subsegment
+ endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1]))
+ fig.lines += [matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=colors[i],
+ alpha=1, linewidth=lw) for i in range(n_lines)]
\ No newline at end of file
diff --git a/scalelsd/ssl/models/__init__.py b/scalelsd/ssl/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f33847ebe1b18e6cfa9338d0c71ee4002f65ac
--- /dev/null
+++ b/scalelsd/ssl/models/__init__.py
@@ -0,0 +1,3 @@
+
+from .detector import ScaleLSD
+from . import detector
diff --git a/scalelsd/ssl/models/detector.py b/scalelsd/ssl/models/detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..32c03c6fcd7130eb6f13c8ffc852e1e849de4116
--- /dev/null
+++ b/scalelsd/ssl/models/detector.py
@@ -0,0 +1,362 @@
+from collections import defaultdict
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+import time
+
+from ..backbones import build_backbone
+from .hafm import HAFMencoder
+from .losses import *
+
+import math
+import cv2
+import matplotlib.pyplot as plt
+
+class ScaleLSD(nn.Module):
+    def __init__(self, gray_scale=False, use_layer_scale=False, enable_attention_hooks=False):
+        """Create the backbone, the HAFM encoder and the loss modules.
+
+        Args:
+            gray_scale: passed through to ``build_backbone``.
+            use_layer_scale: passed through to ``build_backbone``.
+            enable_attention_hooks: passed through to ``build_backbone``.
+        """
+        super(ScaleLSD, self).__init__()
+
+        # ``forward_test`` reads these two settings through ``self`` and
+        # ``configure`` stores them on the class.  Install the defaults on the
+        # class, and only when they are absent, so an un-configured model has a
+        # value to read while anything written by ``configure`` (before or
+        # after construction) is never shadowed by an instance attribute.
+        owner = type(self)
+        if not hasattr(owner, 'num_junctions_inference'):
+            owner.num_junctions_inference = 512
+        if not hasattr(owner, 'junction_threshold_hm'):
+            owner.junction_threshold_hm = 0.008
+
+        self.distance_threshold = 5.0
+
+        # The encoder shares the detector's distance threshold.
+        self.hafm_encoder = HAFMencoder(dis_th=self.distance_threshold)
+
+        self.backbone = build_backbone(
+            gray_scale=gray_scale,
+            use_layer_scale=use_layer_scale,
+            enable_attention_hooks=enable_attention_hooks,
+        )
+
+        self.j2l_threshold = 10
+
+        self.num_residuals = 0
+
+        self.loss = nn.CrossEntropyLoss(reduction='none')
+        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
+        self.stride = self.backbone.stride
+        self.train_step = 0
+
+    @classmethod
+    def configure(cls, opts):
+        """Copy inference settings from parsed options onto the class.
+
+        Args:
+            opts: object exposing ``num_junctions`` and ``junction_hm``
+                attributes (the destinations of the options added by ``cli``).
+
+        Configuration is best effort: as soon as one attribute is missing from
+        ``opts`` the remaining settings are left unchanged.
+        """
+        try:
+            cls.num_junctions_inference = opts.num_junctions
+            cls.junction_threshold_hm = opts.junction_hm
+        except AttributeError:
+            # ``opts`` does not carry the detector options; keep current values.
+            # Only the missing-attribute case is tolerated so that unrelated
+            # errors (and KeyboardInterrupt/SystemExit) are no longer swallowed.
+            pass
+
+    @classmethod
+    def cli(cls, parser):
+        """Register the detector's command-line options on ``parser``.
+
+        Adds ``-nj/--num-junctions`` (int, default 512) and
+        ``-jh/--junction-hm`` (float, default 0.008).
+
+        Args:
+            parser: an ``argparse`` parser or argument group.
+
+        Registering the options twice on the same parser is tolerated: the
+        conflict reported by ``argparse`` is ignored and the existing
+        definitions are kept.
+        """
+        import argparse  # local import keeps this method self-contained
+
+        try:
+            parser.add_argument('-nj', '--num-junctions', default=512, type=int, help='number of junctions')
+            parser.add_argument('-jh', '--junction-hm', default=0.008, type=float, help='junction threshold heatmap')
+        except argparse.ArgumentError:
+            # Option strings already present on this parser.
+            pass
+
+ def hafm_decoding(self,md_maps, dis_maps, residual_maps, scale=5.0, flatten = True, return_points = False):
+ """Decode angle and distance field maps into one line segment per pixel.
+
+ Every segment is produced as (x_start, y_start, x_end, y_end) in the
+ pixel grid of the input maps, with each coordinate clamped to that grid.
+
+ Args:
+ md_maps: 4-D tensor (batch, channels, height, width); channels 0, 1 and
+ 2 are read as the normalised direction angle and the two normalised
+ endpoint angles.
+ dis_maps: normalised distance map, combined with ``residual_maps`` and
+ clamped to [0, 1] before being multiplied by the scale.
+ residual_maps: optional tensor multiplied by the signed offsets in
+ ``sign_pad`` and added to ``dis_maps``; may be None.
+ scale: NOTE(review): has no effect -- it is overwritten with
+ ``self.distance_threshold`` on entry.
+ flatten: reshape lines to (batch, -1, 4) and points to (batch, -1, 2).
+ return_points: also return the pixel coordinates of the decoding grid.
+
+ Returns:
+ ``lines``, or the tuple ``(lines, points)`` when ``return_points`` is
+ true.
+ """
+
+ device = md_maps.device
+ # The caller-supplied ``scale`` is discarded here.
+ scale = self.distance_threshold
+
+ batch_size, _, height, width = md_maps.shape
+ _y = torch.arange(0,height,device=device).float()
+ _x = torch.arange(0,width, device=device).float()
+
+ # Pixel-coordinate grids; the two indexing lines add leading axes so they
+ # become (1, 1, height, width).
+ y0, x0 =torch.meshgrid(_y, _x,indexing='ij')
+ y0 = y0[None,None]
+ x0 = x0[None,None]
+
+ # Signed integer offsets -num_residuals..num_residuals, shaped (1, K, 1, 1)
+ # with K = 2*num_residuals + 1.
+ sign_pad = torch.arange(-self.num_residuals,self.num_residuals+1,device=device,dtype=torch.float32).reshape(1,-1,1,1)
+
+ if residual_maps is not None:
+ residual = residual_maps*sign_pad
+ distance_fields = dis_maps + residual
+ else:
+ distance_fields = dis_maps
+ distance_fields = distance_fields.clamp(min=0,max=1.0)
+ # Undo the angle normalisation: (v - 0.5) * 2*pi for the direction,
+ # v * pi/2 and -v * pi/2 for the two endpoint angles.
+ md_un = (md_maps[:,:1] - 0.5)*np.pi*2
+ st_un = md_maps[:,1:2]*np.pi/2.0
+ ed_un = -md_maps[:,2:3]*np.pi/2.0
+
+ cs_md = md_un.cos()
+ ss_md = md_un.sin()
+
+ y_st = torch.tan(st_un)
+ y_ed = torch.tan(ed_un)
+
+ # Rotate the offsets (1, tan(angle)) by the direction angle, then scale
+ # them by the clamped distance and by ``scale``.
+ x_st_rotated = (cs_md - ss_md*y_st)*distance_fields*scale
+ y_st_rotated = (ss_md + cs_md*y_st)*distance_fields*scale
+
+ x_ed_rotated = (cs_md - ss_md*y_ed)*distance_fields*scale
+ y_ed_rotated = (ss_md + cs_md*y_ed)*distance_fields*scale
+
+ # Translate to absolute positions and keep every endpoint inside the grid.
+ x_st_final = (x_st_rotated + x0).clamp(min=0,max=width-1)
+ y_st_final = (y_st_rotated + y0).clamp(min=0,max=height-1)
+
+ x_ed_final = (x_ed_rotated + x0).clamp(min=0,max=width-1)
+ y_ed_final = (y_ed_rotated + y0).clamp(min=0,max=height-1)
+
+
+ lines = torch.stack((x_st_final,y_st_final,x_ed_final,y_ed_final),dim=-1)
+ if flatten:
+ lines = lines.reshape(batch_size,-1,4)
+ if return_points:
+ points = torch.stack((x0,y0),dim=-1)
+ # Repeat the grid once per batch item and once per residual offset.
+ points = points.repeat((batch_size,2*self.num_residuals+1,1,1,1))
+ if flatten:
+ points = points.reshape(batch_size,-1,2)
+ return lines, points
+
+ return lines
+
+    @staticmethod
+    def non_maximum_suppression(a, kernel_size=3):
+        """Zero every entry of ``a`` that is not the maximum of its window.
+
+        Args:
+            a: score map accepted by ``F.max_pool2d``.
+            kernel_size: side of the square comparison window.
+
+        Returns:
+            A tensor shaped like ``a`` that keeps the values equal to the
+            max-pooled map and is zero elsewhere.
+        """
+        half_window = kernel_size // 2
+        window_max = F.max_pool2d(a, kernel_size, stride=1, padding=half_window)
+        # 1.0 where the entry equals its window maximum, 0.0 otherwise.
+        peak_indicator = torch.eq(a, window_max).float().clamp(min=0.0)
+        suppressed = a * peak_indicator
+        return suppressed
+
+    @staticmethod
+    def get_junctions(jloc, joff, topk = 300, th=0):
+        """Pick the ``topk`` strongest junction cells and refine their positions.
+
+        Args:
+            jloc: junction score map; its third dimension is used as the row
+                width when converting flat indices back to (row, column).
+            joff: offset map, reshaped to (2, -1); row 0 holds the x offsets
+                and row 1 the y offsets.
+            topk: number of candidates taken with ``torch.topk``.
+            th: when positive, only candidates scoring above ``th`` are kept.
+
+        Returns:
+            ``(junctions, scores)`` where ``junctions`` has one (x, y) row per
+            kept candidate.
+        """
+        row_width = jloc.size(2)
+        flat_scores = jloc.reshape(-1)
+        flat_offsets = joff.reshape(2, -1)
+
+        scores, flat_index = torch.topk(flat_scores, k=topk)
+
+        # Flat index -> cell row/column, then add the predicted offset and half
+        # a cell.
+        cell_row = torch.div(flat_index, row_width, rounding_mode='trunc').float()
+        cell_col = (flat_index % row_width).float()
+        y = cell_row + torch.gather(flat_offsets[1], 0, flat_index) + 0.5
+        x = cell_col + torch.gather(flat_offsets[0], 0, flat_index) + 0.5
+
+        junctions = torch.stack((x, y)).t()
+
+        if not th > 0:
+            return junctions, scores
+
+        keep = scores > th
+        return junctions[keep], scores[keep]
+
+ def wireframe_matcher(self, juncs_pred, lines_pred, hat_points, is_train=False):
+ cost1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1)
+ cost2 = torch.sum((lines_pred[:,2:]-juncs_pred[:,None])**2,dim=-1)
+
+ dis1, idx_junc_to_end1 = cost1.min(dim=0)
+ dis2, idx_junc_to_end2 = cost2.min(dim=0)
+ length = torch.sum((lines_pred[:,:2]-lines_pred[:,2:])**2,dim=-1)
+
+ idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2)
+ idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2)
+
+ iskeep = idx_junc_to_end_min < idx_junc_to_end_max ## not the same junction
+ if self.j2l_threshold>0:
+ iskeep *= (dis10:
+ for auxput in auxputs:
+ loss_dict = self.compute_loss(auxput, targets, mask, loss_dict)
+
+ for key in extra_info.keys():
+ extra_info[key] = extra_info[key]/batch_size
+
+ return loss_dict, extra_info
+
+
+ @torch.no_grad()
+ def forward_test(self, images, annotations=None, merge=False):
+ """Run inference and assemble one wireframe result per image.
+
+ Args:
+ images: 4-D image batch (batch, channels, height, width).
+ annotations: dict read for the keys 'use_lsd' (set to True when
+ missing), 'use_nms', 'width' and 'height'.
+ NOTE(review): the default ``None`` fails at ``annotations.keys()``,
+ and 'use_nms' has no fallback -- confirm callers always pass both.
+ merge: NOTE(review): not referenced anywhere in this method.
+
+ Returns:
+ ``(output_list, {})`` where ``output_list`` holds one dict per image
+ with the keys 'lines_pred', 'lines_score', 'juncs_pred',
+ 'lines_support', 'juncs_score', 'graph', 'width' and 'height'.
+ """
+ # NOTE(review): ``device``, ``height`` and ``width`` are assigned but not
+ # used below; only ``batch_size`` is.
+ device = images.device
+ batch_size, _, height, width = images.shape
+
+ outputs, features, aux = self.forward_backbone(images)
+
+ if "use_lsd" not in annotations.keys():
+ annotations["use_lsd"] = True
+
+ # use lsd for theta prediction
+ if annotations['use_lsd']:
+ # Size of the field maps: image size divided by the backbone stride.
+ ws = images.shape[3]//self.stride
+ hs = images.shape[2]//self.stride
+ lsd = cv2.createLineSegmentDetector(0)
+
+ md_lsd_batch = []
+ dis_lsd_batch = []
+ for i in range(batch_size):
+ # Channel 0 of the image, scaled by 255 and cast to uint8 for OpenCV.
+ image = np.array(images[i,0].cpu().numpy()*255,dtype=np.uint8)
+ # NOTE(review): confirm what ``detect`` returns when no segment is
+ # found; ``.reshape`` is called on the result unconditionally.
+ lsd_lines = lsd.detect(image)[0].reshape(-1,4)
+
+ # transform lsd lines to lsd-hat-field
+ md_lsd, dis_lsd, _ = self.hafm_encoder.lines2hafm(torch.from_numpy(lsd_lines).to(images.device)/self.stride, hs, ws)
+ md_lsd_batch.append(md_lsd)
+ dis_lsd_batch.append(dis_lsd)
+
+ md_pred = torch.stack(md_lsd_batch, dim=0)
+ dis_pred = torch.stack(dis_lsd_batch, dim=0)
+
+ # for junctions/endpoints extraction
+ # Only channel 0 of the LSD-derived angle map is kept; channels 1 and 2
+ # are replaced by the network prediction.
+ md_pred[:,1:3] = outputs[:,1:3].sigmoid()
+ # dist
+ # The stacked LSD distance map assigned above is overwritten here.
+ dis_pred = outputs[:,3:4].sigmoid()
+
+ jloc_pred= outputs[:,5:7].softmax(1)[:,1:]
+ joff_pred= outputs[:,7:9].sigmoid() - 0.5
+ else:
+
+ # Everything comes from the network output channels.
+ md_pred = outputs[:,:3].sigmoid()
+ dis_pred = outputs[:,3:4].sigmoid()
+ # NOTE(review): ``res_pred`` and ``jloc_logits`` are not used below.
+ res_pred = outputs[:,4:5].sigmoid()
+ jloc_pred= outputs[:,5:7].softmax(1)[:,1:]
+ jloc_logits = outputs[:,5:7].softmax(1)
+ joff_pred= outputs[:,7:9].sigmoid() - 0.5
+
+ # Residual maps are not used at test time (None is passed).
+ lines_pred_batch, hat_points_batch = self.hafm_decoding(md_pred, dis_pred, None, flatten = True, return_points=True)
+
+ output_list = []
+ # Per-image junction-to-junction accumulator, sized by the junction cap.
+ graph_pred = torch.zeros((batch_size, self.num_junctions_inference, self.num_junctions_inference), device=images.device)
+ for i in range(batch_size):
+ if annotations['use_nms']:
+ jloc_pred_nms = self.non_maximum_suppression(jloc_pred[i])
+ else:
+ jloc_pred_nms = self.non_maximum_suppression(jloc_pred[i], kernel_size=1)
+ # Never request more junctions than there are cells above the threshold.
+ topK = min(self.num_junctions_inference, int((jloc_pred_nms>self.junction_threshold_hm).float().sum().item()))
+ juncs_pred, juncs_score = self.get_junctions(jloc_pred_nms,joff_pred[i], topk=topK, th=self.junction_threshold_hm)
+
+ lines_adjusted, indices_adj, supports, hat_points, counts = self.wireframe_matcher(juncs_pred, lines_pred_batch[i], hat_points_batch[i])
+
+ # Scale factors from field-map coordinates to the size given in
+ # ``annotations``.
+ jscales = torch.tensor(
+ [
+ annotations['width']/md_pred.size(3),
+ annotations['height']/md_pred.size(2)
+ ],
+ device=images.device
+ )
+
+ junctions = juncs_pred * jscales
+ supports = [_*self.stride for _ in supports]
+
+ num_junctions = junctions.shape[0]
+ # Add the counts symmetrically so the graph is undirected.
+ graph_pred[i, indices_adj[:,0], indices_adj[:,1]] += counts.float()
+ graph_pred[i, indices_adj[:,1], indices_adj[:,0]] += counts.float()
+ graph_i = graph_pred[i,:num_junctions,:num_junctions]
+ # Upper triangle only, so each junction pair is listed once.
+ edges = graph_i.triu().nonzero()
+ lines = junctions[edges].reshape(-1,4)
+ scores = graph_pred[i, edges[:,0], edges[:,1]]
+
+ output_list.append(
+ {
+ 'lines_pred': lines,
+ 'lines_score': scores,
+ 'juncs_pred': junctions,
+ 'lines_support': supports,
+ 'juncs_score': juncs_score,
+ 'graph': graph_i,
+ 'width': annotations['width'],
+ 'height': annotations['height'],
+ }
+ )
+
+ return output_list, {}
+
diff --git a/scalelsd/ssl/models/hafm.py b/scalelsd/ssl/models/hafm.py
new file mode 100644
index 0000000000000000000000000000000000000000..89c1d2dd24a713cd07a8fc25896997d7694770b2
--- /dev/null
+++ b/scalelsd/ssl/models/hafm.py
@@ -0,0 +1,161 @@
+import torch
+import numpy as np
+from torch.utils.data.dataloader import default_collate
+
+from scalelsd.base.csrc import _C
+
+class HAFMencoder(object):
+    """Encode junction/edge annotations into dense training targets.
+
+    Args:
+        dis_th: distances are clamped to this value and divided by it; pixels
+            farther than ``dis_th`` are masked out.
+        ang_th: angle threshold, expressed as a fraction of pi/2, applied to
+            the two endpoint angles when building the mask.
+    """
+
+    def __init__(self, dis_th = 10, ang_th = 0):
+        self.dis_th = dis_th
+        self.ang_th = ang_th
+
+    def __call__(self,annotations):
+        """Encode a whole batch.
+
+        Args:
+            annotations: dict with the keys 'batch_size', 'stride', 'width',
+                'height', 'junctions' and 'line_map' (the last two indexed per
+                batch item).
+
+        Returns:
+            ``(targets, metas)``: the per-image target dicts collated with
+            ``default_collate`` and the list of per-image meta dicts.
+        """
+        targets = []
+        metas = []
+        batch_size = annotations['batch_size']
+        stride = annotations['stride']
+        # The target size is shared by every image of the batch.
+        width = annotations['width']//stride
+        height = annotations['height']//stride
+        for batch_id in range(batch_size):
+            # Swap the two junction columns and bring them to target resolution.
+            junctions = annotations['junctions'][batch_id].clone()[:,[1,0]]/float(stride)
+            # Upper triangle of the line map: one index pair per edge.
+            edge_indices = annotations['line_map'][batch_id].triu().nonzero()
+
+            t, m = self.encoding_single_image(junctions,edge_indices,height,width)
+
+            targets.append(t)
+            metas.append(m)
+
+        return default_collate(targets),metas
+
+    def adjacent_matrix(self, n, edges, device):
+        """Return a symmetric (n+1, n+1) boolean adjacency matrix for ``edges``."""
+        mat = torch.zeros(n+1,n+1,dtype=torch.bool,device=device)
+        if edges.size(0)>0:
+            mat[edges[:,0], edges[:,1]] = True
+            mat[edges[:,1], edges[:,0]] = True
+        return mat
+
+    def lines2hafm(self, lines, height, width):
+        """Convert line segments into angle, distance and mask maps.
+
+        Args:
+            lines: tensor with one row of four values per segment, handed to
+                the compiled ``_C.encodels`` routine.
+            height: target map height.
+            width: target map width.
+
+        Returns:
+            ``(hafm_ang, hafm_dis, mask)``; for empty ``lines`` these are
+            all-zero tensors of shape (3, height, width), (1, height, width)
+            and (1, height, width).
+        """
+        device = lines.device
+        if lines.shape[0] == 0:
+            return (
+                torch.zeros((3,height,width),device=device),
+                torch.zeros((1,height,width),device=device),
+                torch.zeros((1,height,width),device=device),
+            )
+
+        # NOTE(review): ``lmap`` is read as six channels below -- [0:2] offset
+        # to the line, [2:4] and [4:6] offsets to its endpoints; confirm
+        # against the ``encodels`` implementation.
+        lmap, _, _ = _C.encodels(lines,height,width,height,width,lines.size(0))
+        dismap = torch.sqrt(lmap[0]**2+lmap[1]**2)[None]
+
+        def _normalize(inp):
+            mag = torch.sqrt(inp[0]*inp[0]+inp[1]*inp[1])
+            return inp/(mag+1e-6)
+
+        # Only the direction field is normalised; the endpoint fields are used
+        # with their raw magnitudes.
+        md_map = _normalize(lmap[:2])
+        st_map = lmap[2:4]
+        ed_map = lmap[4:]
+
+        md_ = md_map.reshape(2,-1).t()
+        st_ = st_map.reshape(2,-1).t()
+        ed_ = ed_map.reshape(2,-1).t()
+        # Per-pixel 2x2 matrix [[c, s], [-s, c]] built from the unit direction
+        # (c, s); it expresses the endpoint offsets in the direction's frame.
+        Rt = torch.cat(
+            (torch.cat((md_[:,None,None,0],md_[:,None,None,1]),dim=2),
+             torch.cat((-md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1)
+        Rtst_ = torch.bmm(Rt, st_[:,:,None]).squeeze(-1).t()
+        Rted_ = torch.bmm(Rt, ed_[:,:,None]).squeeze(-1).t()
+
+        # Exchange the two endpoints where the first one has a negative and the
+        # second one a positive second coordinate in the rotated frame.
+        swap_mask = (Rtst_[1] < 0) & (Rted_[1] > 0)
+        pos_ = Rtst_.clone()
+        neg_ = Rted_.clone()
+        temp = pos_[:,swap_mask]
+        pos_[:,swap_mask] = neg_[:,swap_mask]
+        neg_[:,swap_mask] = temp
+
+        # Keep the coordinates strictly away from zero with the expected sign.
+        pos_[0] = pos_[0].clamp(min=1e-9)
+        pos_[1] = pos_[1].clamp(min=1e-9)
+        neg_[0] = neg_[0].clamp(min=1e-9)
+        neg_[1] = neg_[1].clamp(max=-1e-9)
+
+        mask = (dismap.view(-1)<=self.dis_th).float()
+
+        pos_map = pos_.reshape(-1,height,width)
+        neg_map = neg_.reshape(-1,height,width)
+
+        md_angle = torch.atan2(md_map[1], md_map[0])
+        pos_angle = torch.atan2(pos_map[1],pos_map[0])
+        neg_angle = torch.atan2(neg_map[1],neg_map[0])
+
+        # Drop pixels whose endpoint angles do not clear the angle threshold.
+        mask *= (pos_angle.reshape(-1)>self.ang_th*np.pi/2.0)
+        mask *= (neg_angle.reshape(-1)<-self.ang_th*np.pi/2.0)
+
+        # Normalise: endpoint angles by pi/2, direction angle to [0, 1].
+        pos_angle_n = pos_angle/(np.pi/2)
+        neg_angle_n = -neg_angle/(np.pi/2)
+        md_angle_n = md_angle/(np.pi*2) + 0.5
+        mask = mask.reshape(height,width)
+
+        hafm_ang = torch.cat((md_angle_n[None],pos_angle_n[None],neg_angle_n[None],),dim=0)
+        hafm_dis = dismap.clamp(max=self.dis_th)/self.dis_th
+        mask = mask[None]
+        return hafm_ang, hafm_dis, mask
+
+    def encoding_single_image(self, junctions, edge_indices, height, width):
+        """Build the target and meta dictionaries for one image.
+
+        Args:
+            junctions: tensor with one (x, y) row per junction, already at
+                target resolution.
+            edge_indices: tensor of junction index pairs, one row per line.
+            height: target map height.
+            width: target map width.
+
+        Returns:
+            ``(target, meta)``.  ``target`` has the keys 'jloc', 'joff', 'md',
+            'dis' and 'mask'; ``meta`` has 'junc', 'lines', 'Lpos', 'lpre' and
+            'lpre_label' ('Lpos' and 'lpre_label' are None without junctions).
+        """
+        device = junctions.device
+
+        jmap = np.zeros((height,width),dtype=np.float32)
+        joff = np.zeros((2,height,width),dtype=np.float32)
+
+        if junctions.shape[0] > 0:
+            junctions_np = junctions.cpu().numpy()
+            xint, yint = junctions_np[:,0].astype(np.int32), junctions_np[:,1].astype(np.int32)
+            # Offset of each junction from the centre of its cell.
+            off_x = junctions_np[:,0] - np.floor(junctions_np[:,0]) - 0.5
+            off_y = junctions_np[:,1] - np.floor(junctions_np[:,1]) - 0.5
+
+            jmap[yint,xint] = 1
+            joff[0,yint,xint] = off_x
+            joff[1,yint,xint] = off_y
+
+            lines = junctions[edge_indices].reshape(-1,4)
+            pos_mat = self.adjacent_matrix(junctions.size(0), edge_indices, device)
+            labels = torch.ones((lines.shape[0],),device=device)
+        else:
+            lines = torch.empty((0,4),device=device)
+            pos_mat = None
+            labels = None
+
+        jmap = torch.from_numpy(jmap).to(device)
+        joff = torch.from_numpy(joff).to(device)
+        hafm_ang, hafm_dis, hafm_mask = self.lines2hafm(lines,height,width)
+
+        target = {
+            'jloc': jmap[None],
+            'joff': joff,
+            'md': hafm_ang,
+            'dis': hafm_dis,
+            'mask': hafm_mask
+        }
+
+        meta = {
+            'junc': junctions,
+            'lines': lines,
+            'Lpos': pos_mat,
+            'lpre': lines,
+            'lpre_label': labels
+        }
+        return target, meta
+
+
\ No newline at end of file
diff --git a/scalelsd/ssl/models/losses.py b/scalelsd/ssl/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4fc7fc8c5028a60117f89174b6c2be6325635e2
--- /dev/null
+++ b/scalelsd/ssl/models/losses.py
@@ -0,0 +1,71 @@
+import torch
+import torch.nn.functional as F
+def cross_entropy_loss_for_junction(logits, positive):
+    """Mean two-class cross entropy against a soft foreground map.
+
+    Args:
+        logits: scores with the two classes along dimension 1
+            (index 0 = negative, index 1 = positive).
+        positive: weight of the positive class for every element; the
+            negative class receives ``1 - positive``.
+
+    Returns:
+        Scalar tensor: the mean of the weighted negative log-probabilities.
+    """
+    log_prob = F.log_softmax(logits, dim=1)
+    # Keep the class axis as a singleton so the terms broadcast with `positive`.
+    neg_log_bg = -log_prob[:, 0:1]
+    neg_log_fg = -log_prob[:, 1:2]
+
+    weighted = positive * neg_log_fg + (1 - positive) * neg_log_bg
+    return weighted.mean()
+
+def focal_loss_for_junction(logits, positive, gamma=2.0):
+    """Mean focal-weighted cross entropy for two-class junction logits.
+
+    Args:
+        logits: scores with the classes along dimension 1.
+        positive: passed to ``F.cross_entropy`` as the target and also used as
+            the foreground indicator when computing the probability of the
+            true class.  NOTE(review): the shapes this double use implies
+            should be confirmed against the callers.
+        gamma: exponent of the ``(1 - p_t)`` modulating factor.
+
+    Returns:
+        Scalar tensor with the mean modulated loss.
+    """
+    class_prob = F.softmax(logits, 1)
+    plain_ce = F.cross_entropy(logits, positive, reduction='none')
+
+    fg_prob = class_prob[:, 1:]
+    bg_prob = class_prob[:, :1]
+    # Probability the model assigns to the true class of each element.
+    true_class_prob = fg_prob * positive + bg_prob * (1 - positive)
+
+    modulated = plain_ce * ((1 - true_class_prob) ** gamma)
+    return modulated.mean()
+
+def sigmoid_l1_loss(logits, targets, offset = 0.0, mask=None):
+    """Mean absolute error between ``sigmoid(logits) + offset`` and ``targets``.
+
+    Args:
+        logits: raw predictions.
+        targets: regression targets.
+        offset: constant added to the sigmoid output before comparison.
+        mask: optional weight map.  It is divided by its own mean over
+            dimensions 3 and 2 (a zero mean is replaced by 1) and multiplied
+            into the per-element error.
+
+    Returns:
+        Scalar tensor with the mean (optionally weighted) error.
+    """
+    prediction = torch.sigmoid(logits) + offset
+    abs_error = (prediction - targets).abs()
+
+    if mask is None:
+        return abs_error.mean()
+
+    mask_mean = mask.mean(dim=3, keepdim=True).mean(dim=2, keepdim=True)
+    # Avoid dividing by zero for samples whose mask is entirely empty.
+    mask_mean = mask_mean.masked_fill(mask_mean == 0, 1)
+    weighted_error = abs_error * (mask / mask_mean)
+    return weighted_error.mean()
+
+def sigmoid_focal_loss(
+    inputs: torch.Tensor,
+    targets: torch.Tensor,
+    alpha: float = -1,
+    gamma: float = 2,
+    reduction: str = "none",
+) -> torch.Tensor:
+    """
+    Binary focal loss (RetinaNet, https://arxiv.org/abs/1708.02002).
+    Args:
+        inputs: Float tensor of any shape holding the raw predictions (logits).
+        targets: Float tensor of the same shape as ``inputs`` with the binary
+                 label of every element (0 = negative, 1 = positive).
+        alpha: Balancing weight in (0, 1) given to positive elements
+               (negatives get ``1 - alpha``). A negative value, the default
+               -1, disables the weighting.
+        gamma: Exponent of the ``(1 - p_t)`` factor that down-weights easy
+               examples.
+        reduction: 'none' | 'mean' | 'sum'
+                 'none': return the element-wise loss.
+                 'mean': return the average of the loss.
+                 'sum': return the sum of the loss.
+    Returns:
+        The loss tensor after the requested reduction.
+    """
+    logits = inputs.float()
+    labels = targets.float()
+
+    prob = torch.sigmoid(logits)
+    bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
+    # Probability assigned to the true label of each element.
+    prob_true = prob * labels + (1 - prob) * (1 - labels)
+    focal = bce * ((1 - prob_true) ** gamma)
+
+    if alpha >= 0:
+        class_weight = alpha * labels + (1 - alpha) * (1 - labels)
+        focal = class_weight * focal
+
+    if reduction == "mean":
+        return focal.mean()
+    if reduction == "sum":
+        return focal.sum()
+    return focal
\ No newline at end of file
diff --git a/scripts/predict.sh b/scripts/predict.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8074ee0e98c2c28a56325b50e23ab19ec4aeba65
--- /dev/null
+++ b/scripts/predict.sh
@@ -0,0 +1,18 @@
+
+# Run the predictor over a folder of validation images.
+test_root=data-ssl/hybrid_dataset/COCO/val2017
+
+threshold=10
+jhm=0.1
+res=512
+
+# Optional flags; append them to the command below to enable:
+#   --use_lsd
+#   --use_nms
+#   --whitebg 1.
+#
+# Expansions are quoted so a path containing spaces stays one argument, and the
+# command ends on its last real option instead of continuing into comments.
+python -m predictor.predict \
+    --img "$test_root" \
+    --ext png \
+    --threshold "$threshold" \
+    --width "$res" --height "$res" \
+    --junction-hm "$jhm" \
+    --disable-show
+
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..c93b4b21e82747bad5e69b9720263bdf7e4f35ba
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,37 @@
+import glob
+import os
+
+from setuptools import find_packages
+from setuptools import setup
+
+# Runtime dependencies of the package.
+INSTALL_REQUIRES = [
+    "torch",
+    "torchvision",
+    "accelerate",
+    "tensorboard",
+    "timm",
+    "opencv-python-headless==4.8.1.78",
+    "kornia",
+    "cython",
+    "matplotlib",
+    "yacs",
+    "scikit-image",
+    "tqdm",
+    "python-json-logger",
+    "h5py",
+    "shapely",
+    "seaborn",
+    "easydict",
+]
+
+# Optional dependency groups, installable as ``scalelsd[dev]``.
+EXTRAS_REQUIRE = {
+    "dev": [
+        "pycolmap",
+    ],
+}
+
+setup(
+    name="scalelsd",
+    version="1.0",
+    author="Zeran Ke and Nan Xue",
+    description="Scalable Deep Line Segment Detection Streamlined",
+    packages=find_packages(),
+    install_requires=INSTALL_REQUIRES,
+    extras_require=EXTRAS_REQUIRE,
+)