diff --git a/LEGAL.md b/LEGAL.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6fdb1dfc367716a58d514323d9ba5391fefa50d0 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Nan Xue + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index bd6988349f832ccae381a2419b7eedbb24b169a8..c03789130684cad6c265d4be96863e661d1aed7c 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,117 @@ ---- -title: ScaleLSD -emoji: 🌍 -colorFrom: indigo -colorTo: indigo -sdk: gradio -sdk_version: 5.33.0 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +
+ +# ScaleLSD: Scalable Deep Line Segment Detection Streamlined + + + + + + + +[Zeran Ke](https://calmke.github.io/)<sup>1,2</sup>, [Bin Tan](https://icetttb.github.io/)<sup>2</sup>, [Xianwei Zheng](https://jszy.whu.edu.cn/zhengxianwei/zh_CN/index.htm)<sup>1</sup>, [Yujun Shen](https://shenyujun.github.io/)<sup>2</sup>, [Tianfu Wu](https://research.ece.ncsu.edu/ivmcl/)<sup>3</sup>, [Nan Xue](https://xuenan.net/)<sup>2†</sup> + +<sup>1</sup>Wuhan University&nbsp; &nbsp;<sup>2</sup>Ant Group&nbsp; <sup>3</sup>NC State University + +
+ + + +![teaser](assets/teaser.jpg) + + +## ⚙️ Installation + +All code has been successfully tested on: + +- Ubuntu 22.04.5 LTS +- CUDA 12.1 +- Python 3.10 +- PyTorch 2.5.1 + +First clone this repo: + +```bash +git clone https://github.com/ant-research/scalelsd.git +``` + +Then create the conda environment and install the dependencies: +```bash +conda create -n scalelsd python=3.10 +pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121 +pip install -r requirements.txt +pip install -e . # Install scalelsd locally +``` + +## 🔥πŸ” Gradio Demo + +### Line Segment Detection +Before you start, please download our pre-trained [models](https://huggingface.co/cherubicxn/scalelsd) and place them into the `models` folder. Then run the Gradio demo: +```bash +python -m gradio_demo.inference +``` + +### Line Matching +Because our line matching app is built on GlueStick with our ScaleLSD, you need to install [GlueStick](https://github.com/cvg/GlueStick) and download the weights of the GlueStick model. Then run the Gradio demo: +```bash +python -m gradio_demo.line_mat_gluestick +``` + +## 🚗 Inference + +Quickly start using our models for line segment detection by running the following command: +```bash +python -m predictor.predict --img $[IMAGE_PATH_OR_FOLDER] +``` + +You can also specify more parameters: + +```bash +python -m predictor.predict \ + --ckpt $[MODEL_PATH] \ + --img $[IMAGE_PATH_OR_FOLDER] \ + --ext $[png/pdf/json] \ + --threshold 10 \ + --junction-hm 0.1 \ + --disable-show +``` + +```bash +OPTIONS: + --ckpt CKPT, -c CKPT + Path to the checkpoint file. + --img IMG, -i IMG Path to the image or folder containing images. + --ext EXT, -e EXT Output file extension (png/pdf/json). + --threshold THRESHOLD, -t THRESHOLD + Threshold for line segment detection. + --junction-hm JUNCTION_HM, -jh JUNCTION_HM + Junction heatmap threshold. 
+ --num-junctions NUM_JUNCTIONS, -nj NUM_JUNCTIONS + Max number of junctions to detect. + --disable-show Disable showing the results. + --use_lsd Use LSD-Rectifier for line segment detection. + --use_nms Use Non-Maximum Suppression (NMS) for junction detection. +``` + + +## πŸ“– Related Third-party Projects + +- [HAWPv3](https://github.com/cherubicXN/hawp/tree/main) +- [DeepLSD](https://github.com/cvg/DeepLSD) +- [Progressive-x](https://github.com/danini/progressive-x/tree/vanishing-points) +- [GlueStick](https://github.com/cvg/GlueStick) +- [GlueFactory](https://github.com/cvg/glue-factory) +- [LiMAP](https://github.com/cvg/limap) + + +## πŸ“ Citation + +If you find our work useful in your research, please consider citing: + +```bash +@inproceedings{ScaleLSD, + title = {ScaleLSD: Scalable Deep Line Segment Detection Streamlined}, + author = {Zeran Ke and Bin Tan and Xianwei Zheng and Yujun Shen and Tianfu Wu and Nan Xue}, + booktitle = "IEEE Conference on Computer Vision and Pattern Recognition (CVPR)", + year = {2025}, +} +``` diff --git a/gradio_demo/inference.py b/gradio_demo/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..0a38ec1fdee2f7fc610bea60be5d186947269545 --- /dev/null +++ b/gradio_demo/inference.py @@ -0,0 +1,252 @@ +import torch +import cv2 +import os +import gradio as gr +import numpy as np +import random +from pathlib import Path +import json + +from scalelsd.ssl.models.detector import ScaleLSD +from scalelsd.base import show, WireframeGraph +from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model + +# Title for the Gradio interface +_TITLE = 'Gradio Demo of ScaleLSD for Structured Representation of Images' +MAX_SEED = 1000 + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + """random seed""" + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + +def stop_run(): + """stop run""" + return ( + gr.update(value="Run", variant="primary", visible=True), + 
gr.update(visible=False), + ) + +def process_image( + input_image, + model_name='scalelsd-vitbase-v2-train-sa1b.pt', + save_name='temp_output', + threshold=10, + junction_threshold_hm=0.008, + num_junctions_inference=512, + width=512, + height=512, + line_width=2, + juncs_size=4, + whitebg=0.0, + draw_junctions_only=False, + use_lsd=False, + use_nms=False, + edge_color='orange', + vertex_color='Cyan', + output_format='png', + seed=0, + randomize_seed=False +): + """core processing function for image inference""" + # set random seed + seed = int(randomize_seed_fn(seed, randomize_seed)) + fix_seeds(seed) + + # initialize model + ckpt = "models/" + model_name + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = load_scalelsd_model(ckpt, device) + + # set model parameters + model.junction_threshold_hm = junction_threshold_hm + model.num_junctions_inference = num_junctions_inference + + # transform input image + if isinstance(input_image, np.ndarray): + image = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY) + else: + image = cv2.imread(input_image, 0) + + # resize + ori_shape = image.shape[:2] + image_resized = cv2.resize(image.copy(), (width, height)) + image_tensor = torch.from_numpy(image_resized).float() / 255.0 + image_tensor = image_tensor[None, None].to('cuda') + + # meta data + meta = { + 'width': ori_shape[1], + 'height': ori_shape[0], + 'filename': '', + 'use_lsd': use_lsd, + 'use_nms': use_nms, + } + + # inference + with torch.no_grad(): + outputs, _ = model(image_tensor, meta) + outputs = outputs[0] + + # visual results + painter = show.painters.HAWPainter() + painter.confidence_threshold = threshold + painter.line_width = line_width + painter.marker_size = juncs_size + if whitebg > 0.0: + show.Canvas.white_overlay = whitebg + + temp_folder = "temp_output" + os.makedirs(temp_folder, exist_ok=True) + fig_file = f"{temp_folder}/{save_name}.png" + with show.image_canvas(input_image, fig_file=fig_file) as ax: + if draw_junctions_only: 
+ painter.draw_junctions(ax, outputs) + else: + painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color) + # read the result image + result_image = cv2.imread(fig_file) + + if output_format != 'png': + fig_file = f"{temp_folder}/{save_name}.{output_format}" + with show.image_canvas(input_image, fig_file=fig_file) as ax: + if draw_junctions_only: + painter.draw_junctions(ax, outputs) + else: + painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color) + + json_file = f"{temp_folder}/{save_name}.json" + indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'],outputs['lines_pred']) + wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices, outputs['lines_score'], outputs['width'], outputs['height']) + with open(json_file, 'w') as f: + json.dump(wireframe.jsonize(),f) + + + return result_image[:, :, ::-1], json_file, fig_file + + +def run_demo(): + """create the Gradio demo interface""" + css = """ + #col-container { + margin: 0 auto; + max-width: 800px; + } + """ + + with gr.Blocks(css=css, title=_TITLE) as demo: + with gr.Column(elem_id="col-container"): + gr.Markdown(f'# {_TITLE}') + gr.Markdown("Detect wireframe structures in images using ScaleLSD model") + + pid = gr.State() + figs_root = "assets/figs" + example_images = [os.path.join(figs_root, iname) for iname in os.listdir(figs_root)] + + with gr.Row(): + input_image = gr.Image(example_images[0], label="Input Image", type="numpy") + output_image = gr.Image(label="Detection Result") + + with gr.Row(): + run_btn = gr.Button(value="Run", variant="primary") + stop_btn = gr.Button(value="Stop", variant="stop", visible=False) + + with gr.Row(): + json_file = gr.File(label="Download JSON Output", type="filepath") + image_file = gr.File(label="Download Image Output", type="filepath") + + with gr.Accordion("Advanced Settings", open=True): + with gr.Row(): + model_name = gr.Dropdown( + [ckpt for ckpt in os.listdir('models') if 
ckpt.endswith('.pt')], + value='scalelsd-vitbase-v1-train-sa1b.pt', + label="Model Selection" + ) + + with gr.Row(): + save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files") + + with gr.Row(): + with gr.Column(): + threshold = gr.Number(10, label="Line Threshold") + junction_threshold_hm = gr.Number(0.008, label="Junction Threshold") + num_junctions_inference = gr.Number(1024, label="Max Number of Junctions") + width = gr.Number(512, label="Input Width") + height = gr.Number(512, label="Input Height") + + with gr.Column(): + draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only") + use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier") + use_nms = gr.Checkbox(True, label="Use NMS") + output_format = gr.Dropdown( + ['png', 'jpg', 'pdf'], + value='png', + label="Output Format" + ) + whitebg = gr.Slider(0.0, 1.0, value=0.7, label="White Background Opacity") + line_width = gr.Number(2, label="Line Width") + juncs_size = gr.Number(8, label="Junctions Size") + + with gr.Row(): + edge_color = gr.Dropdown( + ['orange', 'midnightblue', 'red', 'green'], + value='orange', + label="Edge Color" + ) + vertex_color = gr.Dropdown( + ['Cyan', 'deeppink', 'yellow', 'purple'], + value='Cyan', + label="Vertex Color" + ) + + with gr.Row(): + randomize_seed = gr.Checkbox(False, label="Randomize Seed") + seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed") + + gr.Examples( + examples=example_images, + inputs=input_image, + ) + + # star event handlers + run_event = run_btn.click( + fn=process_image, + inputs=[ + input_image, + model_name, + save_name, + threshold, + junction_threshold_hm, + num_junctions_inference, + width, + height, + line_width, + juncs_size, + whitebg, + draw_junctions_only, + use_lsd, + use_nms, + edge_color, + vertex_color, + output_format, + seed, + randomize_seed + ], + outputs=[output_image, json_file, image_file], + ) + + # stop event handlers + stop_btn.click( + fn=stop_run, + 
outputs=[run_btn, stop_btn], + cancels=[run_event], + queue=False, + ) + + + return demo + +if __name__ == "__main__": + run_demo().launch() diff --git a/gradio_demo/line_mat_gluestick.py b/gradio_demo/line_mat_gluestick.py new file mode 100644 index 0000000000000000000000000000000000000000..04fbb099df7645378907a5c000478183b2e59043 --- /dev/null +++ b/gradio_demo/line_mat_gluestick.py @@ -0,0 +1,386 @@ +import argparse +import os +from os.path import join +import sys +import numpy as np +import cv2 +import torch +from matplotlib import pyplot as plt +from tqdm import tqdm +import gradio as gr +import random + +from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT +from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches + +from scalelsd.ssl.models.detector import ScaleLSD +from scalelsd.base import show, WireframeGraph +from scalelsd.ssl.datasets.transforms.homographic_transforms import sample_homography +from scalelsd.ssl.misc.train_utils import fix_seeds +from line_matching.two_view_pipeline import TwoViewPipeline + +from kornia.geometry import warp_perspective,transform_points + +class HADConfig: + num_iter = 1 + valid_border_margin = 3 + translation = True + rotation = True + scale = True + perspective = True + scaling_amplitude = 0.2 + perspective_amplitude_x = 0.2 + perspective_amplitude_y = 0.2 + allow_artifacts = False + patch_ratio = 0.85 +had_cfg = HADConfig() + +# Evaluation config +default_conf = { + 'name': 'two_view_pipeline', + 'use_lines': True, + 'extractor': { + 'name': 'wireframe', + 'sp_params': { + 'force_num_keypoints': False, + 'max_num_keypoints': 2048, + }, + 'wireframe_params': { + 'merge_points': True, + 'merge_line_endpoints': True, + # 'merge_line_endpoints': False, + }, + 'max_n_lines': 512, + }, + 'matcher': { + 'name': 'gluestick', + 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'), + 'trainable': False, + }, + 
'ground_truth': { + 'from_pose_depth': False, + } +} + +# Title for the Gradio interface +_TITLE = 'ScaleLSD-GlueStick Line Matching' +MAX_SEED = 1000 + +def sample_homographics(height, width): + + def scale_homography(H, stride): + H_scaled = H.clone() + H_scaled[:, :, 2, :2] *= stride + H_scaled[:, :, :2, 2] /= stride + return H_scaled + + homographic = sample_homography( + shape = (height, width), + perspective = had_cfg.perspective, + scaling = had_cfg.scale, + rotation = had_cfg.rotation, + translation = had_cfg.translation, + scaling_amplitude = had_cfg.scaling_amplitude, + perspective_amplitude_x = had_cfg.perspective_amplitude_x, + perspective_amplitude_y = had_cfg.perspective_amplitude_y, + patch_ratio = had_cfg.patch_ratio, + allow_artifacts = False + )[0] + + homographic = torch.from_numpy(homographic[None]).float().cuda() + homographic_inv = torch.inverse(homographic) + + H = { + 'h.1': homographic, + 'ih.1': homographic_inv, + } + + return H + +def trans_image_with_homograpy(image): + h, w = image.shape[:2] + H = sample_homographics(height=h, width=w) + + image_warped = warp_perspective(torch.Tensor(image).permute(2,0,1)[None].cuda(), H['h.1'], (h,w)) + image_warped_ = image_warped[0].permute(1,2,0).cpu().numpy().astype(np.uint8) + plt.imshow(image_warped_) + plt.show() + return image_warped_ + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + """random seed""" + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + +def stop_run(): + """stop run""" + return ( + gr.update(value="Run", variant="primary", visible=True), + gr.update(visible=False), + ) + +def clear_image2(): + return None # returning None will clear the image component + +def process_image( + input_image1='assets/figs/sa_1119229.jpg', + input_image2=None, + model_name='scalelsd-vitbase-v1-train-sa1b.pt', + save_name='temp', + threshold=5, + junction_threshold_hm=0.008, + num_junctions_inference=4096, + width=512, + height=512, + line_width=2, + 
juncs_size=4, + whitebg=1.0, + draw_junctions_only=False, + use_lsd=False, + use_nms=False, + edge_color='midnightblue', + vertex_color='deeppink', + output_format='png', + seed=0, + randomize_seed=False +): + """core processing function for image inference""" + # set random seed + seed = int(randomize_seed_fn(seed, randomize_seed)) + fix_seeds(seed) + + conf = { + 'model_name': model_name, + 'threshold': threshold, + 'junction_threshold_hm': junction_threshold_hm, + 'num_junctions_inference': num_junctions_inference, + 'use_lsd': use_lsd, + 'use_nms': use_nms, + 'width': width, + 'height': height, + } + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + pipeline_model = TwoViewPipeline(default_conf).to(device).eval() + pipeline_model.extractor.update_conf(conf) + + saveto = f'temp_output/matching_results' + image1 = cv2.cvtColor(input_image1, cv2.COLOR_BGR2RGB) + cv2.imwrite(f'{saveto}/image.png', image1) + input_image1 = f'{saveto}/image.png' + if input_image2 is None: + image2 = trans_image_with_homograpy(image1) + else: + image2 = cv2.cvtColor(input_image2, cv2.COLOR_BGR2RGB) + cv2.imwrite(f'{saveto}/image2.png', image2) + input_image2 = f'{saveto}/image2.png' + + gray0 = cv2.imread(input_image1, 0) + gray1 = cv2.imread(input_image2, 0) + + torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1) + torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None] + + x = {'image0': torch_gray0, 'image1': torch_gray1} + pred = pipeline_model(x) + + pred = batch_to_np(pred) + kp0, kp1 = pred["keypoints0"], pred["keypoints1"] + m0 = pred["matches0"] + + line_seg0, line_seg1 = pred["lines0"], pred["lines1"] + line_matches = pred["line_matches0"] + + valid_matches = m0 != -1 + match_indices = m0[valid_matches] + matched_kps0 = kp0[valid_matches] + matched_kps1 = kp1[match_indices] + + valid_matches = line_matches != -1 + match_indices = line_matches[valid_matches] + matched_lines0 = 
line_seg0[valid_matches] + matched_lines1 = line_seg1[match_indices] + + img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR) + + mat_file = f'{saveto}/{save_name}_mat.png' + plot_images([img0, img1], dpi=200, pad=2.0) + plot_lines([line_seg0, line_seg1], ps=4, lw=2) + plt.gcf().canvas.manager.set_window_title('Detected Lines') + # plt.tight_layout() + plt.savefig(mat_file) + det_image = cv2.imread(mat_file)[:,:,::-1] + + det_file = f'{saveto}/{save_name}_mat.png' + plot_images([img0, img1], dpi=200, pad=2.0) + plot_color_line_matches([matched_lines0, matched_lines1], lw=3) + plt.gcf().canvas.manager.set_window_title('Line Matches') + # plt.tight_layout() + plt.savefig(det_file) + mat_image = cv2.imread(det_file)[:,:,::-1] + + show.Canvas.white_overlay = whitebg + painter = show.painters.HAWPainter() + + fig_file = f'{saveto}/{save_name}_det1.png' + outputs = {'lines_pred': line_seg0.reshape(-1,4)} + with show.image_canvas(input_image1, fig_file=fig_file) as ax: + painter.draw_wireframe(ax,outputs, edge_color=edge_color, vertex_color=vertex_color) + det1_image = cv2.imread(fig_file)[:,:,::-1] + + fig_file = f'{saveto}/{save_name}_det2.png' + outputs = {'lines_pred': line_seg1.reshape(-1,4)} + with show.image_canvas(input_image2, fig_file=fig_file) as ax: + painter.draw_wireframe(ax,outputs, edge_color=edge_color, vertex_color=vertex_color) + det2_image = cv2.imread(fig_file)[:,:,::-1] + + return image2[:,:,::-1], mat_image, det_image, det1_image, det2_image, mat_file, det_file + + +def demo(): + """create the Gradio demo interface""" + css = """ + #col-container { + margin: 0 auto; + max-width: 800px; + } + """ + + with gr.Blocks(css=css, title=_TITLE) as demo: + with gr.Column(elem_id="col-container"): + gr.Markdown(f'# {_TITLE}') + gr.Markdown("Detect wireframe structures in images using ScaleLSD model") + + pid = gr.State() + figs_root = "assets/mat_figs" + example_single = [os.path.join(figs_root, 'single', iname) for 
iname in os.listdir(figs_root+'/single')] + example_pairs = [[img, None] for img in example_single] + example_pairs += [ + [os.path.join(figs_root, 'pairs', f'ref_{i}.png'), + os.path.join(figs_root, 'pairs', f'tgt_{i}.png')] + for i in [10, 72, 76, 95, 149, 151] + ] + + with gr.Row(): + input_image1 = gr.Image(example_pairs[0][0], label="Input Image1", type="numpy") + input_image2 = gr.Image(label="Input Image2", type="numpy") + + with gr.Row(): + mat_images = gr.Image(label="Matching Results") + with gr.Row(): + det_images = gr.Image(label="Detection Results") + with gr.Row(): + det_image1 = gr.Image(label="Detection1") + det_image2 = gr.Image(label="Detection2") + + with gr.Row(): + run_btn = gr.Button(value="Run", variant="primary") + stop_btn = gr.Button(value="Stop", variant="stop", visible=False) + + with gr.Row(): + mat_file = gr.File(label="Download Matching Result", type="filepath") + det_file = gr.File(label="Download Detection Result", type="filepath") + + with gr.Accordion("Advanced Settings", open=True): + with gr.Row(): + model_name = gr.Dropdown( + [ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt')], + value='scalelsd-vitbase-v1-train-sa1b.pt', + label="Model Selection" + ) + + with gr.Row(): + save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files") + + with gr.Row(): + with gr.Column(): + threshold = gr.Number(10, label="Line Threshold") + junction_threshold_hm = gr.Number(0.008, label="Junction Threshold") + num_junctions_inference = gr.Number(1024, label="Max Number of Junctions") + width = gr.Number(512, label="Input Width") + height = gr.Number(512, label="Input Height") + + with gr.Column(): + draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only") + use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier") + use_nms = gr.Checkbox(True, label="Use NMS") + output_format = gr.Dropdown( + ['png', 'jpg', 'pdf'], + value='png', + label="Output Format" + ) + whitebg = 
gr.Slider(0.0, 1.0, value=1.0, label="White Background Opacity") + line_width = gr.Number(2, label="Line Width") + juncs_size = gr.Number(8, label="Junctions Size") + + with gr.Row(): + edge_color = gr.Dropdown( + ['orange', 'midnightblue', 'red', 'green'], + value='midnightblue', + label="Edge Color" + ) + vertex_color = gr.Dropdown( + ['Cyan', 'deeppink', 'yellow', 'purple'], + value='deeppink', + label="Vertex Color" + ) + + with gr.Row(): + randomize_seed = gr.Checkbox(False, label="Randomize Seed") + seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed") + + gr.Examples( + examples=example_pairs, + inputs=[input_image1, input_image2] + ) + + # star event handlers + run_event = run_btn.click( + fn=process_image, + inputs=[ + input_image1, + input_image2, + model_name, + save_name, + threshold, + junction_threshold_hm, + num_junctions_inference, + width, + height, + line_width, + juncs_size, + whitebg, + draw_junctions_only, + use_lsd, + use_nms, + edge_color, + vertex_color, + output_format, + seed, + randomize_seed + ], + outputs=[input_image2, mat_images, det_images, det_image1, det_image2, mat_file, det_file], + ) + + # stop event handlers + stop_btn.click( + fn=stop_run, + outputs=[run_btn, stop_btn], + cancels=[run_event], + queue=False, + ) + + # When image1 changes, image2 is cleared + input_image1.change( + fn=clear_image2, + outputs=input_image2 + ) + + + return demo + +if __name__ == "__main__": + # ε―εŠ¨εΊ”η”¨ + demo = demo() + demo.launch() diff --git a/line_matching/run.py b/line_matching/run.py new file mode 100644 index 0000000000000000000000000000000000000000..0f7d50c2186cb1bf9ddd2bbc6ebe6f8888c0511d --- /dev/null +++ b/line_matching/run.py @@ -0,0 +1,191 @@ +import argparse +import os +from os.path import join +import sys +import numpy as np +import cv2 +import torch +from matplotlib import pyplot as plt +from tqdm import tqdm + +from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT +from gluestick.drawing import 
plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches +from line_matching.two_view_pipeline import TwoViewPipeline + +from scalelsd.base import show, WireframeGraph +from scalelsd.ssl.datasets.transforms.homographic_transforms import sample_homography +from kornia.geometry import warp_perspective,transform_points + +class HADConfig: + num_iter = 1 + valid_border_margin = 3 + translation = True + rotation = True + scale = True + perspective = True + scaling_amplitude = 0.2 + perspective_amplitude_x = 0.2 + perspective_amplitude_y = 0.2 + allow_artifacts = False + patch_ratio = 0.85 +had_cfg = HADConfig() + +def sample_homographics(height, width): + + def scale_homography(H, stride): + H_scaled = H.clone() + H_scaled[:, :, 2, :2] *= stride + H_scaled[:, :, :2, 2] /= stride + return H_scaled + + homographic = sample_homography( + shape = (height, width), + perspective = had_cfg.perspective, + scaling = had_cfg.scale, + rotation = had_cfg.rotation, + translation = had_cfg.translation, + scaling_amplitude = had_cfg.scaling_amplitude, + perspective_amplitude_x = had_cfg.perspective_amplitude_x, + perspective_amplitude_y = had_cfg.perspective_amplitude_y, + patch_ratio = had_cfg.patch_ratio, + allow_artifacts = False + )[0] + + homographic = torch.from_numpy(homographic[None]).float().cuda() + homographic_inv = torch.inverse(homographic) + + H = { + 'h.1': homographic, + 'ih.1': homographic_inv, + } + + return H + +def trans_image_with_homograpy(image): + h, w = image.shape[:2] + H = sample_homographics(height=h, width=w) + + image_warped = warp_perspective(torch.Tensor(image).permute(2,0,1)[None].cuda(), H['h.1'], (h,w)) + image_warped_ = image_warped[0].permute(1,2,0).cpu().numpy().astype(np.uint8) + plt.imshow(image_warped_) + plt.show() + return image_warped_ + + +def main(): + # Parse input parameters + parser = argparse.ArgumentParser( + prog='GlueStick Demo', + description='Demo app to show the point and line matches obtained by GlueStick') 
+ parser.add_argument('-img1', default='assets/figs/sa_1119229.jpg') + parser.add_argument('-img2', default=None) + parser.add_argument('--max_pts', type=int, default=1000) + parser.add_argument('--max_lines', type=int, default=300) + parser.add_argument('--model', type=str, default='models/paper-sa1b-997pkgs-model.pt') + args = parser.parse_args() + + # important + if args.img1 is None and args.img2 is None: + raise ValueError("Input at least one path of image1 or image2") + + # Evaluation config + conf = { + 'name': 'two_view_pipeline', + 'use_lines': True, + 'extractor': { + 'name': 'wireframe', + 'sp_params': { + 'force_num_keypoints': False, + 'max_num_keypoints': args.max_pts, + }, + 'wireframe_params': { + 'merge_points': True, + 'merge_line_endpoints': True, + # 'merge_line_endpoints': False, + }, + 'max_n_lines': args.max_lines, + }, + 'matcher': { + 'name': 'gluestick', + 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'), + 'trainable': False, + }, + 'ground_truth': { + 'from_pose_depth': False, + } + } + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + pipeline_model = TwoViewPipeline(conf).to(device).eval() + pipeline_model.extractor.update_conf(None) + + saveto = f'temp_output/matching_results' + os.makedirs(saveto, exist_ok=True) + + image1 = cv2.cvtColor(cv2.imread(args.img1), cv2.COLOR_BGR2RGB) + if args.img2 is None: + image2 = trans_image_with_homograpy(image1) + cv2.imwrite(f'{saveto}/warped_image.png', image2) + args.img2 = f'{saveto}/warped_image.png' + + gray0 = cv2.imread(args.img1, 0) + gray1 = cv2.imread(args.img2, 0) + + torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1) + torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None] + + x = {'image0': torch_gray0, 'image1': torch_gray1} + pred = pipeline_model(x) + + pred = batch_to_np(pred) + kp0, kp1 = pred["keypoints0"], pred["keypoints1"] + m0 = pred["matches0"] + + line_seg0, 
line_seg1 = pred["lines0"], pred["lines1"] + line_matches = pred["line_matches0"] + + valid_matches = m0 != -1 + match_indices = m0[valid_matches] + matched_kps0 = kp0[valid_matches] + matched_kps1 = kp1[match_indices] + + valid_matches = line_matches != -1 + match_indices = line_matches[valid_matches] + matched_lines0 = line_seg0[valid_matches] + matched_lines1 = line_seg1[match_indices] + + # Plot the matches + gray0 = cv2.imread(args.img1, 0) + gray1 = cv2.imread(args.img2, 0) + img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR) + + plot_images([img0, img1], dpi=200, pad=2.0) + plot_lines([line_seg0, line_seg1], ps=4, lw=2) + plt.gcf().canvas.manager.set_window_title('Detected Lines') + # plt.tight_layout() + plt.savefig(f'{saveto}/det.png') + + plot_images([img0, img1], dpi=200, pad=2.0) + plot_color_line_matches([matched_lines0, matched_lines1], lw=3) + plt.gcf().canvas.manager.set_window_title('Line Matches') + # plt.tight_layout() + plt.savefig(f'{saveto}/mat.png') + + whitebg = 1 + show.Canvas.white_overlay = whitebg + painter = show.painters.HAWPainter() + + fig_file = f'{saveto}/det1.png' + outputs = {'lines_pred': line_seg0.reshape(-1,4)} + with show.image_canvas(args.img1, fig_file=fig_file) as ax: + # painter.draw_wireframe(ax,outputs, edge_color='orange', vertex_color='Cyan') + painter.draw_wireframe(ax,outputs, edge_color='midnightblue', vertex_color='deeppink') + fig_file = f'{saveto}/det2.png' + outputs = {'lines_pred': line_seg1.reshape(-1,4)} + with show.image_canvas(args.img2, fig_file=fig_file) as ax: + painter.draw_wireframe(ax,outputs, edge_color='midnightblue', vertex_color='deeppink') + + + +if __name__ == '__main__': + main() diff --git a/line_matching/run_list.py b/line_matching/run_list.py new file mode 100644 index 0000000000000000000000000000000000000000..170b8b7acab892504b58e60356e5dbdf7d984e13 --- /dev/null +++ b/line_matching/run_list.py @@ -0,0 +1,144 @@ +import argparse +import os +from 
def _select_pair_ids(inum, imax, root):
    """Return the indices of the ``ref_<i>.png`` / ``tgt_<i>.png`` pairs to process.

    * ``inum`` and ``imax`` given: the inclusive range ``inum..imax``.
    * only ``inum`` given: that single pair.
    * only ``imax`` given: the inclusive range ``0..imax``.
    * neither given: one index per two directory entries found in ``root``.
    """
    if inum is not None and imax is not None:
        return range(inum, imax + 1)
    if inum is not None:
        return [inum]
    if imax is not None:
        return range(imax + 1)
    # Same count as before: half the number of directory entries.
    return range(len(os.listdir(root)) // 2)


def main():
    """Match line segments between image pairs and save the visualisations.

    For every selected pair ``<test_root>/ref_<i>.png`` / ``tgt_<i>.png`` the
    two-view pipeline is run and four figures are written to
    ``temp_output/matching_results/<model>/<i>/``: the detected lines, the
    colour-coded line matches, and one wireframe rendering per image.

    Raises:
        FileNotFoundError: if one image of a pair cannot be read.
    """
    # Parse input parameters
    parser = argparse.ArgumentParser(
        prog='GlueStick Demo',
        description='Demo app to show the point and line matches obtained by GlueStick')
    parser.add_argument('-inum', default=None, type=int)
    parser.add_argument('-imax', default=None, type=int)
    parser.add_argument('-img1', default=join('resources', 'img1.jpg'))
    parser.add_argument('-img2', default=join('resources', 'img2.jpg'))
    parser.add_argument('--max_pts', type=int, default=1000)
    parser.add_argument('--max_lines', type=int, default=300)
    parser.add_argument('--model', default='scalelsd', type=str)
    parser.add_argument('--test_root', type=str, default='data-ssl/0images-pre/')
    args = parser.parse_args()

    # Evaluation config
    conf = {
        'name': 'two_view_pipeline',
        'use_lines': True,
        'extractor': {
            'name': 'wireframe',
            'sp_params': {
                'force_num_keypoints': False,
                'max_num_keypoints': args.max_pts,
            },
            'wireframe_params': {
                'merge_points': True,
                'merge_line_endpoints': True,
            },
            'max_n_lines': args.max_lines,
        },
        'matcher': {
            'name': 'gluestick',
            'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
            'trainable': False,
        },
        'ground_truth': {
            'from_pose_depth': False,
        }
    }

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pipeline_model = TwoViewPipeline(conf).to(device).eval()
    # None makes the extractor take its `extr_conf is None` branch, i.e. the
    # hard-coded ScaleLSD checkpoint and thresholds.
    pipeline_model.extractor.update_conf(None)

    md = args.model
    root = args.test_root

    show.Canvas.white_overlay = 1
    painter = show.painters.HAWPainter()

    for pair_id in tqdm(_select_pair_ids(args.inum, args.imax, root)):
        saveto = f'temp_output/matching_results/{md}/{pair_id}'
        os.makedirs(saveto, exist_ok=True)

        img1_path = join(root, f'ref_{pair_id}.png')
        img2_path = join(root, f'tgt_{pair_id}.png')

        gray0 = cv2.imread(img1_path, 0)
        gray1 = cv2.imread(img2_path, 0)
        # cv2.imread signals a missing/unreadable file by returning None.
        if gray0 is None or gray1 is None:
            raise FileNotFoundError(
                f'Could not read image pair {pair_id}: {img1_path}, {img2_path}')

        torch_gray0 = numpy_image_to_torch(gray0).to(device)[None]
        torch_gray1 = numpy_image_to_torch(gray1).to(device)[None]

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            pred = pipeline_model({'image0': torch_gray0, 'image1': torch_gray1})
        pred = batch_to_np(pred)

        line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
        line_matches = pred["line_matches0"]

        # line_matches[i] is the index in image 1 matched to line i of
        # image 0, or -1 when line i is unmatched.
        valid_matches = line_matches != -1
        matched_lines0 = line_seg0[valid_matches]
        matched_lines1 = line_seg1[line_matches[valid_matches]]

        # Plot the matches
        img0 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR)
        img1 = cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)

        plot_images([img0, img1], dpi=200, pad=2.0)
        plot_lines([line_seg0, line_seg1], ps=4, lw=2)
        plt.gcf().canvas.manager.set_window_title('Detected Lines')
        plt.savefig(f'{saveto}/{md}_det_{pair_id}.png')
        plt.close()  # otherwise one figure per pair stays alive

        plot_images([img0, img1], dpi=200, pad=2.0)
        plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
        plt.gcf().canvas.manager.set_window_title('Line Matches')
        plt.savefig(f'{saveto}/{md}_mat_{pair_id}.png')
        plt.close()

        # Per-image wireframe renderings.
        for tag, img_path, segs in (('det1', img1_path, line_seg0),
                                    ('det2', img2_path, line_seg1)):
            fig_file = f'{saveto}/{md}_{tag}.png'
            outputs = {'lines_pred': segs.reshape(-1, 4)}
            with show.image_canvas(img_path, fig_file=fig_file) as ax:
                painter.draw_wireframe(ax, outputs, edge_color='midnightblue',
                                       vertex_color='deeppink')


if __name__ == '__main__':
    main()
def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
    """Keep only the keypoints lying in one randomly chosen image quadrant.

    The quadrant origin is drawn with ``np.random.choice`` for each axis.
    Boolean-mask indexing flattens the batch dimension, so the result always
    has a leading batch dimension of 1.
    # NOTE(review): this looks like it is only meaningful for batch size 1 — confirm.

    Args:
        keypoints: tensor indexed as ``[..., 0]`` = x and ``[..., 1]`` = y.
        scores: per-keypoint scores, indexed with the same mask.
        descs: descriptors laid out as ``[batch, dim, num_kp]``.
        h, w: image height and width.

    Returns:
        The filtered ``(keypoints, scores, descs)``.
    """
    h2, w2 = h // 2, w // 2
    x0 = int(np.random.choice([0, w2]))
    y0 = int(np.random.choice([0, h2]))
    xs, ys = keypoints[..., 0], keypoints[..., 1]
    valid_mask = (xs >= x0) & (xs < x0 + w2) & (ys >= y0) & (ys < y0 + h2)
    keypoints = keypoints[valid_mask][None]
    scores = scores[valid_mask][None]
    descs = descs.permute(0, 2, 1)[valid_mask].t()[None]
    return keypoints, scores, descs


def keep_random_kp_subset(keypoints, scores, descs, num_selected):
    """Keep a uniformly random subset of at most ``num_selected`` keypoints.

    The same random indices are applied to every batch element.
    """
    num_kp = keypoints.shape[1]
    selected_kp = torch.randperm(num_kp)[:num_selected]
    return (keypoints[:, selected_kp],
            scores[:, selected_kp],
            descs[:, :, selected_kp])


def keep_best_kp_subset(keypoints, scores, descs, num_selected):
    """Keep the ``num_selected`` highest-scoring keypoints of each batch element.

    The kept keypoints are returned in ascending score order. Requests larger
    than the number of keypoints keep everything; ``num_selected <= 0`` keeps
    nothing.

    Args:
        keypoints: ``[batch, num_kp, coord]`` tensor.
        scores: ``[batch, num_kp]`` tensor.
        descs: ``[batch, dim, num_kp]`` tensor.
        num_selected: number of keypoints to keep.

    Returns:
        The filtered ``(keypoints, scores, descs)``.
    """
    num_kp = scores.shape[1]
    k = max(0, min(num_selected, num_kp))
    sorted_indices = torch.sort(scores, dim=1)[1]
    # Slice from an explicit start index: `[-k:]` with k == 0 would select
    # every keypoint instead of none.
    selected_kp = sorted_indices[:, num_kp - k:]
    keypoints = torch.gather(
        keypoints, 1,
        selected_kp[:, :, None].expand(-1, -1, keypoints.shape[-1]))
    scores = torch.gather(scores, 1, selected_kp)
    descs = torch.gather(
        descs, 2, selected_kp[:, None].expand(-1, descs.shape[1], -1))
    return keypoints, scores, descs
class TwoViewPipeline(BaseModel):
    """Two-view matching pipeline: per-image extraction, then matching,
    filtering and solving.

    Each stage runs only when its ``name`` is set in the configuration. Input
    keys ending in ``'0'`` / ``'1'`` belong to the first / second image, and
    per-image outputs get the same suffix appended.
    """

    default_conf = {
        'extractor': {
            'name': 'superpoint',
            'trainable': False,
        },
        'use_lines': False,
        'use_points': True,
        'randomize_num_kp': False,
        'detector': {'name': None},
        'descriptor': {'name': None},
        'matcher': {'name': 'nearest_neighbor_matcher'},
        'filter': {'name': None},
        'solver': {'name': None},
        'ground_truth': {
            'from_pose_depth': False,
            'from_homography': False,
            'th_positive': 3,
            'th_negative': 5,
            'reward_positive': 1,
            'reward_negative': -0.25,
            'is_likelihood_soft': True,
            'p_random_occluders': 0,
            'n_line_sampled_pts': 50,
            'line_perp_dist_th': 5,
            'overlap_th': 0.2,
            'min_visibility_th': 0.5
        },
    }
    required_data_keys = ['image0', 'image1']
    strict_conf = False  # need to pass new confs to children models
    components = [
        'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver']

    def _init(self, conf):
        """Instantiate the sub-models whose ``name`` is set in ``conf``."""
        if conf.extractor.name:
            # The extractor is always the wireframe descriptor, whatever the name.
            self.extractor = SPWireframeDescriptor(conf.extractor)

        if conf.matcher.name:
            self.matcher = get_model(conf.matcher.name)(conf.matcher)
        else:
            # Without a matcher the matches have to come with the input.
            self.required_data_keys += ['matches0']

        for stage in ('filter', 'solver'):
            stage_conf = conf[stage]
            if stage_conf.name:
                setattr(self, stage, get_model(stage_conf.name)(stage_conf))

    def _forward(self, data):
        """Run the configured stages and return the merged predictions."""
        conf = self.conf
        passthrough_keys = (
            'keypoints', 'keypoint_scores', 'descriptors',
            'lines', 'line_scores', 'line_descriptors', 'valid_lines')

        def process_siamese(data, i):
            # Strip the view suffix so sub-models see un-suffixed keys.
            view = {k[:-1]: v for k, v in data.items() if k[-1] == i}
            out = self.extractor(view) if conf.extractor.name else {}
            if conf.detector.name:
                out = self.detector(view)
            else:
                # No detector: forward whatever features came with the input.
                for key in passthrough_keys:
                    if key in view:
                        out[key] = view[key]
            if conf.descriptor.name:
                out = {**out, **self.descriptor({**view, **out})}
            return out

        pred = {}
        for suffix in ('0', '1'):
            view_pred = process_siamese(data, suffix)
            pred.update({k + suffix: v for k, v in view_pred.items()})

        # Each later stage sees the input plus everything predicted so far.
        for stage in ('matcher', 'filter', 'solver'):
            if conf[stage].name:
                pred = {**pred, **getattr(self, stage)({**data, **pred})}

        return pred

    def loss(self, pred, data):
        """Merge the losses of all active components and sum their totals.

        Components whose ``loss`` raises ``NotImplementedError`` are skipped.
        """
        merged = {}
        total = 0
        for name in self.components:
            if not self.conf[name].name:
                continue
            try:
                part = getattr(self, name).loss(pred, {**pred, **data})
            except NotImplementedError:
                continue
            merged.update(part)
            total = part['total'] + total
        return {**merged, 'total': total}

    def metrics(self, pred, data):
        """Merge the metrics of all active components.

        Components whose ``metrics`` raises ``NotImplementedError`` are skipped.
        """
        merged = {}
        for name in self.components:
            if not self.conf[name].name:
                continue
            try:
                part = getattr(self, name).metrics(pred, {**pred, **data})
            except NotImplementedError:
                continue
            merged.update(part)
        return merged
def lines_to_wireframe(lines, line_scores, all_descs, conf):
    """Merge nearby line endpoints into junctions and build the wireframe.

    Endpoints closer than ``conf.nms_radius`` are clustered with DBSCAN. Each
    cluster becomes one junction, placed at the mean of its endpoints and
    scored with the mean score of the lines they belong to.

    Returns:
        junctions: list of [num_junc, 2] tensors listing all wireframe junctions
        junc_scores: list of [num_junc] tensors with the junction score
        junc_descs: list of [dim, num_junc] tensors with the junction descriptors
        connectivity: list of [num_junc, num_junc] bool tensors, True when
            2 junctions are connected (the diagonal is always True)
        new_lines: the new set of [b_size, num_lines, 2, 2] lines
        lines_junc_idx: a [b_size, num_lines, 2] tensor with the junction
            index of each line endpoint
        num_true_junctions: list with the number of junctions of each image
            in the batch
    """
    b_size, _, _, _ = all_descs.shape
    device = lines.device
    endpoints = lines.reshape(b_size, -1, 2)

    junctions, junc_scores, junc_descs, connectivity = [], [], [], []
    new_lines, lines_junc_idx, num_true_junctions = [], [], []
    for b in range(b_size):
        pts = endpoints[b]

        # Cluster the endpoints that are close-by (min_samples=1: every
        # endpoint belongs to some cluster).
        labels = DBSCAN(eps=conf.nms_radius, min_samples=1).fit(
            pts.cpu().numpy()).labels_
        n_junc = len(set(labels))
        num_true_junctions.append(n_junc)
        labels = torch.tensor(labels, dtype=torch.long, device=device)

        # Junction position: mean of the endpoints of its cluster.
        junc_xy = torch.zeros(n_junc, 2, dtype=torch.float, device=device)
        junc_xy.scatter_reduce_(
            0, labels[:, None].repeat(1, 2), pts,
            reduce='mean', include_self=False)
        junctions.append(junc_xy)

        # Junction score: mean over its endpoints of their line's score
        # (each line score is repeated for its two endpoints).
        junc_sc = torch.zeros(n_junc, dtype=torch.float, device=device)
        junc_sc.scatter_reduce_(
            0, labels, torch.repeat_interleave(line_scores[b], 2),
            reduce='mean', include_self=False)
        junc_scores.append(junc_sc)

        # Snap every line endpoint to its junction.
        endpoint_pairs = labels.reshape(-1, 2)
        new_lines.append(junc_xy[labels].reshape(-1, 2, 2))
        lines_junc_idx.append(endpoint_pairs)

        # Two junctions are connected when a line joins them.
        adjacency = torch.eye(n_junc, dtype=torch.bool, device=device)
        adjacency[endpoint_pairs[:, 0], endpoint_pairs[:, 1]] = True
        adjacency[endpoint_pairs[:, 1], endpoint_pairs[:, 0]] = True
        connectivity.append(adjacency)

        # Interpolate the dense descriptors at the junction positions.
        junc_descs.append(sample_descriptors(
            junc_xy[None], all_descs[b:(b + 1)], 8)[0])

    new_lines = torch.stack(new_lines, dim=0)
    lines_junc_idx = torch.stack(lines_junc_idx, dim=0)
    return (junctions, junc_scores, junc_descs, connectivity,
            new_lines, lines_junc_idx, num_true_junctions)
class SPWireframeDescriptor(BaseModel):
    """SuperPoint keypoints plus line segments, merged into one wireframe.

    Lines come from ScaleLSD (default) or LSD, depending on the configuration
    set through :meth:`update_conf`. When ``merge_line_endpoints`` is enabled,
    close-by line endpoints are merged into junctions and concatenated with
    the SuperPoint keypoints.
    """

    default_conf = {
        'sp_params': {
            'has_detector': True,
            'has_descriptor': True,
            'descriptor_dim': 256,
            'trainable': False,

            # Inference
            'return_all': True,
            'sparse_outputs': True,
            'nms_radius': 4,
            'detection_threshold': 0.005,
            'max_num_keypoints': 1000,
            'force_num_keypoints': True,
            'remove_borders': 4,
        },
        'wireframe_params': {
            'merge_points': True,
            'merge_line_endpoints': True,
            'nms_radius': 3,
            'max_n_junctions': 500,
        },
        'max_n_lines': 250,
        'min_length': 15,
    }
    required_data_keys = ['image']

    def _init(self, conf):
        """Build the SuperPoint sub-model and reset the line-extractor config."""
        self.conf = conf
        self.sp = SuperPoint(conf.sp_params)
        self.extr_conf = {}

    def detect_lsd_lines(self, x, max_n_lines=None):
        """Detect line segments with LSD on every image of the batch ``x``.

        Segments shorter than ``conf.min_length`` are dropped. The rest are
        ranked by ``lsd_score * sqrt(length)`` and the best ``max_n_lines``
        are kept (``conf.max_n_lines`` when the argument is None).

        Returns:
            ``(lines, scores, valid_lines)``: the stacked per-image segments
            reshaped to ``[-1, 2, 2]``, their scores, and an all-True mask.
        """
        if max_n_lines is None:
            max_n_lines = self.conf.max_n_lines
        lines, scores, valid_lines = [], [], []
        for b in range(len(x)):
            # For each image on batch
            img = (x[b].squeeze().cpu().numpy() * 255).astype(np.uint8)
            if max_n_lines is None:
                b_segs = lsd(img)
            else:
                # Increase the scale until enough segments are found.
                for s in [0.3, 0.4, 0.5, 0.7, 0.8, 1.0]:
                    b_segs = lsd(img, scale=s)
                    if len(b_segs) >= max_n_lines:
                        break

            segs_length = np.linalg.norm(b_segs[:, 2:4] - b_segs[:, 0:2], axis=1)
            # Remove short lines
            long_enough = segs_length >= self.conf.min_length
            b_segs = b_segs[long_enough]
            segs_length = segs_length[long_enough]
            b_scores = b_segs[:, -1] * np.sqrt(segs_length)
            # Take the most relevant segments
            indices = np.argsort(-b_scores)
            if max_n_lines is not None:
                indices = indices[:max_n_lines]
            lines.append(torch.from_numpy(b_segs[indices, :4].reshape(-1, 2, 2)))
            scores.append(torch.from_numpy(b_scores[indices]))
            valid_lines.append(torch.ones_like(scores[-1], dtype=torch.bool))

        lines = torch.stack(lines).to(x)
        scores = torch.stack(scores).to(x)
        valid_lines = torch.stack(valid_lines).to(x.device)
        return lines, scores, valid_lines

    def update_conf(self, conf):
        """Set the line-extractor configuration.

        ``None`` selects the built-in ScaleLSD defaults. Otherwise ``conf`` is
        a mapping with ``model_name`` (``'lsd'`` or a checkpoint file name
        under ``models/``) plus, for ScaleLSD, ``junction_threshold_hm``,
        ``num_junctions_inference``, ``width``, ``height``, ``threshold``,
        ``use_lsd`` and ``use_nms``.
        """
        self.extr_conf = conf

    def _get_scalelsd_model(self, ckpt, device):
        """Return the ScaleLSD model for ``ckpt``, loading it once per device."""
        # A plain dict in __dict__ keeps the detector out of this module's
        # registered sub-modules and therefore out of its state_dict.
        cache = self.__dict__.setdefault('_scalelsd_models', {})
        key = (ckpt, str(device))
        if key not in cache:
            cache[key] = load_scalelsd_model(ckpt, device)
        return cache[key]

    def _detect_scalelsd_lines(self, image, device, *, ckpt, width, height,
                               threshold, junction_threshold_hm,
                               num_junctions_inference, use_lsd, use_nms):
        """Run ScaleLSD on the first image of the batch.

        The image is resized to ``(width, height)`` for inference, while the
        original size is passed in ``meta``. Only lines with
        ``score >= threshold`` are kept.

        Returns:
            ``(lines, line_scores)`` shaped ``[1, N, 2, 2]`` and ``[1, N]``.
        """
        model = self._get_scalelsd_model(ckpt, device)
        model.junction_threshold_hm = junction_threshold_hm
        model.num_junctions_inference = num_junctions_inference

        image_size = image.shape[-2:]
        image_np = image[0, 0].cpu().numpy()
        resized = torch.from_numpy(cv2.resize(image_np, (width, height))).float()
        meta = {
            'width': image_size[1],
            'height': image_size[0],
            'filename': '',
            'use_lsd': use_lsd,
            'use_nms': use_nms,
        }
        with torch.no_grad():
            outputs, _ = model(resized[None, None].to(device), meta)
        lines = outputs[0]['lines_pred']
        line_scores = outputs[0]['lines_score']
        keep = line_scores >= threshold
        # Reshape the flat 4-value rows to the [B, N, 2, 2] layout that the
        # LSD path produces and the rest of _forward assumes.
        return lines[keep].reshape(1, -1, 2, 2), line_scores[keep][None]

    def _forward(self, data):
        """Extract keypoints and lines and assemble the wireframe.

        Returns a dict with ``keypoints``, ``keypoint_scores``,
        ``descriptors``, ``pl_associativity``, ``num_junctions``, ``lines``,
        ``orig_lines``, ``lines_junc_idx`` and ``line_scores``.
        """
        b_size, _, h, w = data['image'].shape
        device = data['image'].device

        if not self.conf.sp_params.force_num_keypoints:
            assert b_size == 1, "Only batch size of 1 accepted for non padded inputs"

        # Line detection
        if 'lines' not in data or 'line_scores' not in data:
            if self.extr_conf is None:
                # Built-in ScaleLSD defaults.
                lines, line_scores = self._detect_scalelsd_lines(
                    data['image'], device,
                    ckpt='models/scalelsd-vitbase-v1-train-sa1b.pt',
                    width=512, height=512, threshold=5,
                    junction_threshold_hm=0.008,
                    num_junctions_inference=4096,
                    use_lsd=False, use_nms=False)
            elif self.extr_conf['model_name'] != 'lsd':
                cfg = self.extr_conf
                lines, line_scores = self._detect_scalelsd_lines(
                    data['image'], device,
                    ckpt="models/" + cfg['model_name'],
                    width=cfg['width'], height=cfg['height'],
                    threshold=cfg['threshold'],
                    junction_threshold_hm=cfg['junction_threshold_hm'],
                    num_junctions_inference=cfg['num_junctions_inference'],
                    use_lsd=cfg['use_lsd'], use_nms=cfg['use_nms'])
            else:
                if 'original_img' in data:
                    # Detect more lines, because when projecting them to the image most of them will be discarded
                    lines, line_scores, valid_lines = self.detect_lsd_lines(
                        data['original_img'], self.conf.max_n_lines * 3)
                    # Apply the same transformation that is applied in homography_adaptation
                    lines, valid_lines2 = warp_lines_torch(lines, data['H'], False, data['image'].shape[-2:])
                    valid_lines = valid_lines & valid_lines2
                    lines[~valid_lines] = -1
                    line_scores[~valid_lines] = 0
                    # Re-sort the line segments to pick the ones that are inside the image and have bigger score
                    sorted_scores, sorting_indices = torch.sort(line_scores, dim=-1, descending=True)
                    line_scores = sorted_scores[:, :self.conf.max_n_lines]
                    sorting_indices = sorting_indices[:, :self.conf.max_n_lines]
                    lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1)
                    valid_lines = torch.take_along_dim(valid_lines, sorting_indices, 1)
                else:
                    lines, line_scores, valid_lines = self.detect_lsd_lines(data['image'], max_n_lines=1000000)
        else:
            lines, line_scores, valid_lines = data['lines'], data['line_scores'], data['valid_lines']
        # Normalize the scores by the per-image maximum.
        if line_scores.shape[-1] != 0:
            line_scores /= (line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None])

        # SuperPoint prediction
        pred = self.sp(data)

        # Remove keypoints that are too close to line endpoints
        if self.conf.wireframe_params.merge_points:
            kp = pred['keypoints']
            line_endpts = lines.reshape(b_size, -1, 2)
            dist_pt_lines = torch.norm(
                kp[:, :, None] - line_endpts[:, None], dim=-1)
            # For each keypoint, mark it as valid or to remove
            pts_to_remove = torch.any(
                dist_pt_lines < self.conf.sp_params.nms_radius, dim=2)
            # Simply remove them (we assume batch_size = 1 here)
            assert len(kp) == 1
            keep_pts = ~pts_to_remove[0]
            pred['keypoints'] = pred['keypoints'][0][keep_pts][None]
            pred['keypoint_scores'] = pred['keypoint_scores'][0][keep_pts][None]
            pred['descriptors'] = pred['descriptors'][0].T[keep_pts].T[None]

        # Connect the lines together to form a wireframe
        orig_lines = lines.clone()
        if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0:
            # Merge first close-by endpoints to connect lines
            (line_points, line_pts_scores, line_descs, line_association,
             lines, lines_junc_idx, num_true_junctions) = lines_to_wireframe(
                lines, line_scores, pred['all_descriptors'],
                conf=self.conf.wireframe_params)

            # Add the keypoints to the junctions
            (all_points, all_scores, all_descs,
             pl_associativity) = [], [], [], []
            for bs in range(b_size):
                all_points.append(torch.cat(
                    [line_points[bs], pred['keypoints'][bs]], dim=0))
                all_scores.append(torch.cat(
                    [line_pts_scores[bs], pred['keypoint_scores'][bs]], dim=0))
                all_descs.append(torch.cat(
                    [line_descs[bs], pred['descriptors'][bs]], dim=1))

                # Identity for the keypoints, junction connectivity for the
                # leading num_true_junctions block.
                n_junc = num_true_junctions[bs]
                associativity = torch.eye(len(all_points[-1]), dtype=torch.bool, device=device)
                associativity[:n_junc, :n_junc] = line_association[bs][:n_junc, :n_junc]
                pl_associativity.append(associativity)

            all_points = torch.stack(all_points, dim=0)
            all_scores = torch.stack(all_scores, dim=0)
            all_descs = torch.stack(all_descs, dim=0)
            pl_associativity = torch.stack(pl_associativity, dim=0)
        else:
            # Lines are independent
            all_points = torch.cat([lines.reshape(b_size, -1, 2),
                                    pred['keypoints']], dim=1)
            n_pts = all_points.shape[1]
            num_lines = lines.shape[1]
            num_true_junctions = [num_lines * 2] * b_size
            all_scores = torch.cat([
                torch.repeat_interleave(line_scores, 2, dim=1),
                pred['keypoint_scores']], dim=1)
            pred['line_descriptors'] = self.endpoints_pooling(
                lines, pred['all_descriptors'], (h, w))
            all_descs = torch.cat([
                pred['line_descriptors'].reshape(b_size, self.conf.sp_params.descriptor_dim, -1),
                pred['descriptors']], dim=2)
            pl_associativity = torch.eye(
                n_pts, dtype=torch.bool,
                device=device)[None].repeat(b_size, 1, 1)
            lines_junc_idx = torch.arange(
                num_lines * 2, device=device).reshape(1, -1, 2).repeat(b_size, 1, 1)

        del pred['all_descriptors']  # Remove dense descriptors to save memory
        torch.cuda.empty_cache()

        return {'keypoints': all_points,
                'keypoint_scores': all_scores,
                'descriptors': all_descs,
                'pl_associativity': pl_associativity,
                'num_junctions': torch.tensor(num_true_junctions),
                'lines': lines,
                'orig_lines': orig_lines,
                'lines_junc_idx': lines_junc_idx,
                'line_scores': line_scores,
                }

    @staticmethod
    def endpoints_pooling(segs, all_descriptors, img_shape):
        """Sample the dense descriptor map at both endpoints of every segment.

        Endpoint coordinates are rescaled from ``img_shape`` (h, w) to the
        descriptor-map resolution, rounded and clipped to the map bounds.
        """
        assert segs.ndim == 4 and segs.shape[-2:] == (2, 2)
        filter_shape = all_descriptors.shape[-2:]
        scale_x = filter_shape[1] / img_shape[1]
        scale_y = filter_shape[0] / img_shape[0]

        scaled_segs = torch.round(segs * torch.tensor([scale_x, scale_y]).to(segs)).long()
        scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1)
        scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1)
        line_descriptors = [all_descriptors[None, b, ..., torch.squeeze(b_segs[..., 1]), torch.squeeze(b_segs[..., 0])]
                            for b, b_segs in enumerate(scaled_segs)]
        line_descriptors = torch.cat(line_descriptors)
        return line_descriptors  # Shape (1, 256, 308, 2)

    def loss(self, pred, data):
        """Not implemented: this extractor has no loss."""
        raise NotImplementedError

    def metrics(self, pred, data):
        """This extractor reports no metrics."""
        return {}
def parse_args():
    """Parse the command-line options of the ScaleLSD prediction script."""
    aparser = argparse.ArgumentParser()
    aparser.add_argument('-c', '--ckpt', default='models/scalelsd-vitbase-v1-train-sa1b.pt', type=str, help='the path for loading checkpoints')
    aparser.add_argument('-t', '--threshold', default=10, type=float)
    aparser.add_argument('-i', '--img', required=True, type=str)
    aparser.add_argument('--width', default=512, type=int)
    aparser.add_argument('--height', default=512, type=int)
    aparser.add_argument('--whitebg', default=0.0, type=float)
    aparser.add_argument('--saveto', default=None, type=str)
    # 'txt' used to be listed here but was never implemented: it only failed
    # with a ValueError after inference, so it is rejected up front now.
    aparser.add_argument('-e', '--ext', default='pdf', type=str, choices=['pdf', 'png', 'json'])
    aparser.add_argument('--device', default='cuda', type=str, choices=['cuda', 'cpu', 'mps'])
    aparser.add_argument('--disable-show', default=False, action='store_true')
    aparser.add_argument('--draw-junctions-only', default=False, action='store_true')
    aparser.add_argument('--use_lsd', default=False, action='store_true')
    aparser.add_argument('--use_nms', default=False, action='store_true')

    ScaleLSD.cli(aparser)

    args = aparser.parse_args()

    ScaleLSD.configure(args)

    return args


def main():
    """Run ScaleLSD on one image or a folder of images and save the results.

    Each ``.jpg`` / ``.png`` input is resized to ``--width`` x ``--height``
    for inference. The result is saved in ``--saveto`` (default
    ``temp_output/ScaleLSD``) as a figure (``pdf`` / ``png``) or as a JSON
    wireframe, named after the input file.

    Raises:
        ValueError: if ``--img`` is neither a ``.jpg``/``.png`` file nor a
            directory, or the output extension is unsupported.
        FileNotFoundError: if an input image cannot be read.
    """
    args = parse_args()

    model = load_scalelsd_model(args.ckpt, device=args.device)

    # Set up output directory and painter
    if args.saveto is None:
        print('No output directory specified, saving outputs to folder: temp_output/ScaleLSD')
        args.saveto = 'temp_output/ScaleLSD'
    os.makedirs(args.saveto, exist_ok=True)

    show.painters.HAWPainter.confidence_threshold = args.threshold
    show.Canvas.show = not args.disable_show
    if args.whitebg > 0.0:
        show.Canvas.white_overlay = args.whitebg
    painter = show.painters.HAWPainter()
    edge_color = 'orange'
    vertex_color = 'Cyan'

    # Prepare images
    image_exts = ('.jpg', '.png')
    if os.path.isfile(args.img) and args.img.endswith(image_exts):
        all_images = [args.img]
    elif os.path.isdir(args.img):
        all_images = sorted(
            os.path.join(args.img, file)
            for file in os.listdir(args.img)
            if file.endswith(image_exts))
    else:
        raise ValueError('Input must be a file or a directory containing images.')

    # Inference
    for fname in tqdm(all_images):
        pname = Path(fname)
        image = cv2.imread(fname, 0)
        # cv2.imread signals a missing/unreadable file by returning None.
        if image is None:
            raise FileNotFoundError('Could not read image: {}'.format(fname))

        # Resize the input (default [512, 512]); the original size goes in `meta`.
        ori_shape = image.shape[:2]
        image_ = cv2.resize(image, (args.width, args.height))
        image_ = torch.from_numpy(image_).float() / 255.0
        image_ = image_[None, None].to(args.device)

        meta = {
            'width': ori_shape[1],
            'height': ori_shape[0],
            'filename': '',
            'use_lsd': args.use_lsd,
            'use_nms': args.use_nms,
        }

        with torch.no_grad():
            outputs, _ = model(image_, meta)
        outputs = outputs[0]

        if args.ext in ['png', 'pdf']:
            fig_file = osp.join(args.saveto, pname.with_suffix('.' + args.ext).name)
            with show.image_canvas(fname, fig_file=fig_file) as ax:
                if args.draw_junctions_only:
                    painter.draw_junctions(ax, outputs)
                else:
                    painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
        elif args.ext == 'json':
            indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
            wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices, outputs['lines_score'], outputs['width'], outputs['height'])
            outpath = osp.join(args.saveto, pname.with_suffix('.json').name)
            with open(outpath, 'w') as f:
                json.dump(wireframe.jsonize(), f)
        else:
            raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))


if __name__ == "__main__":
    main()
"""JIT-build and expose the ``_C`` extension (line-segment encoding kernel)."""
from torch.utils.cpp_extension import load
import glob
import os.path as osp
import warnings

__this__ = osp.dirname(__file__)

# Build (or fetch from the torch extension cache) exactly once.  The original
# module attempted the build inside ``try/except`` and then unconditionally
# repeated it, which both doubled the work and re-raised the very error the
# ``except`` branch had just absorbed.
try:
    _C = load(
        name='_C',
        sources=[
            osp.join(__this__, 'binding.cpp'),
            osp.join(__this__, 'linesegment.cu'),
        ],
    )
except Exception as err:  # e.g. no compiler / CUDA toolchain available
    # Best-effort: keep the package importable; users of ``_C`` must handle None.
    warnings.warn("Could not build the '_C' extension: {}".format(err))
    _C = None

__all__ = ["_C"]
// For every pixel of the (height x width) output grid, find the closest line
// segment and store the offsets from the pixel to that segment.
//
//   nthreads      number of pixel indices to process (the loop bound)
//   lines         flat array read as lines[4*i .. 4*i+3] = x1, y1, x2, y2
//   input_height,
//   input_width   size the line coordinates refer to; coordinates are scaled
//                 by width/input_width and height/input_height before use
//   num           number of line segments in `lines`
//   map           output with 6 planes of height*width floats:
//                 0-1 offset to the closest point on the nearest segment,
//                 2-3 offset to the endpoint with the smaller squared distance,
//                 4-5 offset to the other endpoint
//   label         not written by this kernel (the only store is commented out)
//   tmap          output, clamped projection parameter t of the nearest segment
//
// If num == 0 nothing is written for any pixel.
__global__ void encode_kernel(const int nthreads, const float* lines,
    const int input_height, const int input_width, const int num,
    const int height, const int width, float* map,
    bool* label, float* tmap)
{
    CUDA_1D_KERNEL_LOOP(index, nthreads){
        // Decompose the flat index into pixel coordinates.
        int w = index % width;
        int h = (index / width) % height;
        // Flat offsets of this pixel inside each of the six planes of `map`.
        int x_index = h*width + w;
        int y_index = height*width + h*width + w;
        int ux_index = 2*height*width + h*width + w;
        int uy_index = 3*height*width + h*width + w;
        int vx_index = 4*height*width + h*width + w;
        int vy_index = 5*height*width + h*width + w;
        int label_index = h*width + w;

        float px = (float) w;
        float py = (float) h;
        float min_dis = 1e30;   // running minimum of the squared distance
        int minp = -1;          // index of the nearest segment so far
        bool flagp = true;      // whether the projection fell inside that segment
        for(int i = 0; i < num; ++i) {
            // Scale from input resolution to output-grid resolution.
            float xs = (float)width /(float)input_width;
            float ys = (float)height /(float)input_height;
            float x1 = lines[4*i ]*xs;
            float y1 = lines[4*i+1]*ys;
            float x2 = lines[4*i+2]*xs;
            float y2 = lines[4*i+3]*ys;

            float dx = x2 - x1;
            float dy = y2 - y1;
            // Offsets from the pixel to the two endpoints.
            float ux = x1 - px;
            float uy = y1 - py;
            float vx = x2 - px;
            float vy = y2 - py;
            float norm2 = dx*dx + dy*dy;
            bool flag = false;
            // Projection parameter of the pixel onto the segment's line;
            // the 1e-6 keeps the division defined for zero-length segments.
            float t = ((px-x1)*dx + (py-y1)*dy)/(norm2+1e-6);
            if (t<=1 && t>=0.0)
                flag = true;

            // Clamp to the segment so the closest point is on the segment itself.
            t = t<0.0? 0.0:t;
            t = t>1.0? 1.0:t;

            // Offset from the pixel to the closest point on the segment.
            float ax = x1 + t*(x2-x1) - px;
            float ay = y1 + t*(y2-y1) - py;

            float dis = ax*ax + ay*ay;
            if (dis < min_dis) {
                // New nearest segment: overwrite every output for this pixel.
                min_dis = dis;
                map[x_index] = ax;
                map[y_index] = ay;
                float norm_u2 = ux*ux+uy*uy;
                float norm_v2 = vx*vx+vy*vy;

                // Store the endpoint with the smaller squared distance first.
                if (norm_u2 < norm_v2){
                    map[ux_index] = ux;
                    map[uy_index] = uy;
                    map[vx_index] = vx;
                    map[vy_index] = vy;
                }
                else{
                    map[ux_index] = vx;
                    map[uy_index] = vy;
                    map[vx_index] = ux;
                    map[vy_index] = uy;
                }

                minp = i;
                if (flag)
                    flagp = true;
                else
                    flagp = false;

                tmap[index] = t;
            }
        }
        // label[label_index+minp*height*width] = flagp;

    }
}
class Canvas:
    """Canvas for plotting.

    All methods expose Axes objects. To get Figure objects, you can ask the
    axis `ax.get_figure()`.

    Configuration lives in class attributes so that it can be changed globally
    (e.g. ``Canvas.show = True``) before any figure is created.
    """

    # Directory for automatically named figures; None disables auto-naming.
    all_images_directory = None
    # Counter used to build the automatic file names.
    all_images_count = 0
    # Call ``plt.show()`` for every figure when True.
    show = False
    # Figure width / height in inches used by ``image`` when no figsize is given.
    image_width = 7.0
    image_height = None
    # DPI used by ``blank`` when none is passed.
    blank_dpi = 200
    image_dpi_factor = 1.0
    image_min_dpi = 50.0
    out_file_extension = 'pdf'
    # Falsy, or the alpha of a white rectangle drawn over the image.
    white_overlay = False

    @classmethod
    def generic_name(cls):
        """Return the next auto-generated figure path, or None if disabled.

        Creates ``all_images_directory`` if needed and increments the counter.
        """
        if cls.all_images_directory is None:
            return None
        os.makedirs(cls.all_images_directory, exist_ok=True)

        cls.all_images_count += 1
        return os.path.join(cls.all_images_directory,
                            '{:04}.{}'.format(cls.all_images_count, cls.out_file_extension))

    @classmethod
    @contextmanager
    def blank(cls, fig_file=None, *, dpi=None, nomargin=False, **kwargs):
        """Context manager yielding the Axes of a new empty figure.

        Args:
            fig_file: where to save the figure on exit; defaults to
                ``generic_name()`` (which may be None, meaning "do not save").
            dpi: figure DPI, defaults to ``blank_dpi``.
            nomargin: remove all spacing around and between subplots.
            **kwargs: forwarded to ``plt.subplots``.

        Raises:
            Exception: if matplotlib is not installed.
        """
        if plt is None:
            raise Exception('please install matplotlib')
        if fig_file is None:
            fig_file = cls.generic_name()

        if dpi is None:
            dpi = cls.blank_dpi

        if 'figsize' not in kwargs:
            kwargs['figsize'] = (10, 6)

        if nomargin:
            gridspec = kwargs.setdefault('gridspec_kw', {})
            gridspec.update(wspace=0, hspace=0,
                            left=0.0, right=1.0, top=1.0, bottom=0.0)

        fig, ax = plt.subplots(dpi=dpi, **kwargs)
        try:
            yield ax

            # The original passed ``not margins`` where ``margins`` is the
            # function imported from pyplot, i.e. always False.  The flag that
            # was meant is this method's own ``nomargin`` argument.
            fig.set_tight_layout(not nomargin)
            if fig_file:
                LOG.debug('writing image to %s', fig_file)
                fig.savefig(fig_file)

            if cls.show:
                plt.show()
        finally:
            # Release the figure even when the caller's drawing code raised.
            plt.close(fig)

    @classmethod
    @contextmanager
    def image(cls, image, fig_file=None, *, margin=None, **kwargs):
        """Context manager yielding an Axes that shows ``image`` without axes.

        Args:
            image: a path (read with OpenCV, channel order reversed) or
                anything ``np.asarray`` accepts.
            fig_file: where to save the figure on exit; defaults to
                ``generic_name()`` (which may be None, meaning "do not save").
            margin: None, a single float, or four floats used to position the
                Axes inside the figure.
            **kwargs: forwarded to ``plt.figure``.

        Raises:
            Exception: if matplotlib is not installed.
            FileNotFoundError: if ``image`` is a path OpenCV cannot read.
        """
        if plt is None:
            raise Exception('please install matplotlib')
        if fig_file is None:
            fig_file = cls.generic_name()

        if isinstance(image, str):
            path = image
            image = cv2.imread(path)
            if image is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError('cannot read image: {}'.format(path))
            image = image[..., ::-1]
        else:
            image = np.asarray(image)

        if margin is None:
            margin = [0.0, 0.0, 0.0, 0.0]
        elif isinstance(margin, float):
            margin = [margin, margin, margin, margin]
        assert len(margin) == 4

        if 'figsize' not in kwargs:
            # compute figure size: use image ratio and take the drawable area
            # into account that is left after subtracting margins.
            image_ratio = image.shape[0] / image.shape[1]
            image_area_ratio = (1.0 - margin[1] - margin[3]) / (1.0 - margin[0] - margin[2])
            if cls.image_width is not None:
                kwargs['figsize'] = (
                    cls.image_width,
                    cls.image_width * image_ratio / image_area_ratio
                )
            elif cls.image_height:
                kwargs['figsize'] = (
                    cls.image_height * image_area_ratio / image_ratio,
                    cls.image_height
                )

        # Fixed output resolution (image_min_dpi / image_dpi_factor are unused here).
        dpi = 200
        fig = plt.figure(dpi=dpi, **kwargs)
        try:
            ax = plt.Axes(fig, [0.0 + margin[0],
                                0.0 + margin[1],
                                1.0 - margin[2],
                                1.0 - margin[3]])

            ax.set_axis_off()
            ax.set_xlim(-0.5, image.shape[1] - 0.5)  # imshow uses center-pixel-coordinates
            ax.set_ylim(image.shape[0] - 0.5, -0.5)
            fig.add_axes(ax)
            ax.imshow(image)
            if cls.white_overlay:
                white_screen(ax, cls.white_overlay)
            yield ax

            if fig_file:
                LOG.debug('writing image to %s', fig_file)
                fig.savefig(fig_file)
            if cls.show:
                # The leftover ``pdb.set_trace()`` that used to follow this
                # call halted every run and has been removed.
                plt.show()
        finally:
            # Release the figure even when the caller's drawing code raised.
            plt.close(fig)
class HAWPainter:
    """Draw predicted line segments and their endpoints on a matplotlib Axes.

    ``line_width``, ``marker_size`` and ``confidence_threshold`` are class
    attributes so they can be configured globally before a painter is built.
    """

    line_width = 2
    marker_size = 4

    confidence_threshold = 0.05

    def __init__(self):
        # Fall back to defaults when the class-level settings were cleared.
        if self.line_width is None:
            self.line_width = 1

        if self.marker_size is None:
            self.marker_size = max(1, int(self.line_width * 0.5))

    def _visible_segments(self, wireframe):
        """Return the rows of ``lines_pred`` to draw, as a non-torch array.

        Rows are filtered by ``lines_score > confidence_threshold`` when the
        prediction carries scores; torch tensors are moved to CPU numpy.
        """
        segments = wireframe['lines_pred']
        if 'lines_score' in wireframe.keys():
            keep = wireframe['lines_score'] > self.confidence_threshold
            segments = segments[keep]
        if isinstance(segments, torch.Tensor):
            segments = segments.cpu().numpy()
        return segments

    def draw_junctions(self, ax, wireframe, *,
                       edge_color = None, vertex_color = None):
        """Plot only the two endpoints of every kept segment.

        ``edge_color`` is accepted for signature symmetry with
        ``draw_wireframe`` but nothing is drawn with it.
        """
        if wireframe is None:
            return

        dot_color = 'c' if vertex_color is None else vertex_color
        segments = self._visible_segments(wireframe)

        for x_col, y_col in ((0, 1), (2, 3)):
            ax.plot(segments[:, x_col], segments[:, y_col], '.', color=dot_color)

    def draw_wireframe(self, ax, wireframe, *,
                       edge_color = None, vertex_color = None):
        """Plot every kept segment as a line plus a marker at each endpoint."""
        if wireframe is None:
            return

        stroke_color = 'b' if edge_color is None else edge_color
        dot_color = 'c' if vertex_color is None else vertex_color
        segments = self._visible_segments(wireframe)

        xs = [segments[:, 0], segments[:, 2]]
        ys = [segments[:, 1], segments[:, 3]]
        ax.plot(xs, ys, '-', color=stroke_color, linewidth=self.line_width)
        for x_col, y_col in ((0, 1), (2, 3)):
            ax.plot(segments[:, x_col], segments[:, y_col], '.',
                    color=dot_color, markersize=self.marker_size)
def setup_logger(name, save_dir, out_file='log.txt', json_format=False, rank=0):
    """Return the logger ``name`` configured for console and file output.

    Handlers are only attached on ``rank == 0``; other ranks get the bare
    logger.  Calling the function again for the same logger does not attach a
    second stdout handler or a second handler for the same log file, so records
    are not emitted multiple times.

    Args:
        name: logger name passed to ``logging.getLogger``.
        save_dir: directory for the log file (created if missing); falsy
            disables file logging.
        out_file: file name inside ``save_dir``.
        json_format: format records with ``jsonlogger.JsonFormatter`` instead
            of the plain ``logging.Formatter``.
        rank: process rank; only rank 0 gets handlers.

    Returns:
        logging.Logger: the configured logger (level DEBUG).
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    fmt = "%(asctime)s %(name)s %(levelname)s: %(message)s"
    if json_format:
        formatter = jsonlogger.JsonFormatter(fmt)
    else:
        formatter = logging.Formatter(fmt)

    if rank == 0:
        # FileHandler subclasses StreamHandler, hence the explicit exclusion.
        has_stdout = any(
            isinstance(h, logging.StreamHandler)
            and not isinstance(h, logging.FileHandler)
            and getattr(h, 'stream', None) is sys.stdout
            for h in logger.handlers
        )
        if not has_stdout:
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(logging.DEBUG)
            ch.setFormatter(formatter)
            logger.addHandler(ch)

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            # FileHandler stores its target as an absolute path in baseFilename.
            log_path = os.path.abspath(os.path.join(save_dir, out_file))
            has_file = any(
                isinstance(h, logging.FileHandler) and h.baseFilename == log_path
                for h in logger.handlers
            )
            if not has_file:
                fh = logging.FileHandler(log_path)
                fh.setLevel(logging.DEBUG)
                fh.setFormatter(formatter)
                logger.addHandler(fh)

    return logger
class MetricLogger(object):
    """Collect named scalar metrics, each tracked by a ``SmoothedValue``.

    Meters are created on first use and can be read as attributes, e.g.
    ``logger.loss`` returns the meter registered under ``"loss"``.
    """

    def __init__(self, delimiter="\t"):
        """Create an empty logger; ``delimiter`` joins entries in ``__str__``."""
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are reduced with ``.item()``.

        Raises:
            AssertionError: if a value is neither a tensor, float nor int.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Resolve unknown attributes as meter names.

        Python only calls this after normal lookup failed, so instance
        attributes never reach it.  ``meters`` is read from ``__dict__``
        directly: going through ``self.meters`` would re-enter this method and
        recurse without end whenever the instance has no ``meters`` yet (for
        example while it is being copied or unpickled).
        """
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
                    type(self).__name__, attr))

    def __str__(self):
        """Format all meters as ``name: median (global_avg)``, sorted by name."""
        loss_str = []
        for name in sorted(self.meters):
            meter = self.meters[name]
            loss_str.append(
                "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
            )
        return self.delimiter.join(loss_str)

    def tensorborad(self, iteration, writter, phase='train'):
        """Write the global average of every meter whose name contains 'loss'.

        Each value is sent to ``writter.add_scalar`` under the tag
        ``'<phase>/global/<name>'`` at step ``iteration``.
        """
        for name, meter in self.meters.items():
            if 'loss' in name:
                writter.add_scalar('{}/global/{}'.format(phase, name), meter.global_avg, iteration)

    # Correctly spelled alias; the original name is kept for existing callers.
    tensorboard = tensorborad
class WireframeGraph:
    """A wireframe: 2-D vertices, edges between them and per-item scores."""

    def __init__(self,
                vertices: torch.Tensor,
                v_confidences: torch.Tensor,
                edges: torch.Tensor,
                edge_weights: torch.Tensor,
                frame_width: int,
                frame_height: int):
        """Store the graph data as given (nothing is copied or validated)."""
        self.vertices = vertices
        self.v_confidences = v_confidences
        self.edges = edges
        self.weights = edge_weights
        self.frame_width = frame_width
        self.frame_height = frame_height

    @classmethod
    def xyxy2indices(cls, junctions, lines):
        """Map each segment endpoint to the index of its nearest junction.

        Args:
            junctions: (N, 2) tensor of junction coordinates.
            lines: (M, 4) tensor of segments as x1, y1, x2, y2.

        Returns:
            (M, 2) tensor of junction indices for the two endpoints.
        """
        dist1 = torch.norm(junctions[None, :, :] - lines[:, None, :2], dim=-1)
        dist2 = torch.norm(junctions[None, :, :] - lines[:, None, 2:], dim=-1)
        idx1 = torch.argmin(dist1, dim=-1)
        idx2 = torch.argmin(dist2, dim=-1)
        return torch.stack((idx1, idx2), dim=-1)

    @classmethod
    def load_json(cls, fname):
        """Build a graph from a file written with ``json.dump(graph.jsonize())``.

        The instance is created through ``cls`` so that subclasses get an
        instance of their own type.
        """
        with open(fname, 'r') as f:
            data = json.load(f)

        vertices = torch.tensor(data['vertices'])
        v_confidences = torch.tensor(data['vertices-score'])
        edges = torch.tensor(data['edges'])
        edge_weights = torch.tensor(data['edges-weights'])
        height = data['height']
        width = data['width']

        return cls(vertices, v_confidences, edges, edge_weights, width, height)

    @property
    def is_empty(self):
        """True if any instance attribute is None."""
        return any(val is None for val in self.__dict__.values())

    @property
    def num_vertices(self):
        """Number of vertices, or 0 for an empty graph."""
        if self.is_empty:
            return 0
        return self.vertices.shape[0]

    @property
    def num_edges(self):
        """Number of edges, or 0 for an empty graph."""
        if self.is_empty:
            return 0
        return self.edges.shape[0]

    def line_segments(self, threshold = 0.05, device=None, to_np=False):
        """Return the edges with weight above ``threshold`` as segments.

        Each row is x1, y1, x2, y2, weight.  The result is moved to ``device``
        when one is given and converted to a numpy array when ``to_np`` is set.
        """
        is_valid = self.weights > threshold
        p1 = self.vertices[self.edges[is_valid, 0]]
        p2 = self.vertices[self.edges[is_valid, 1]]
        ps = self.weights[is_valid]

        lines = torch.cat((p1, p2, ps[:, None]), dim=-1)
        if device is not None:
            lines = lines.to(device)
        if to_np:
            lines = lines.cpu().numpy()

        return lines

    def rescale(self, image_width, image_height):
        """Scale the vertices in place to a new frame size and record it."""
        scale_x = float(image_width) / float(self.frame_width)
        scale_y = float(image_height) / float(self.frame_height)

        self.vertices[:, 0] *= scale_x
        self.vertices[:, 1] *= scale_y
        self.frame_width = image_width
        self.frame_height = image_height

    def jsonize(self):
        """Return a JSON-serialisable dict; the inverse of ``load_json``."""
        return {
            'vertices': self.vertices.cpu().tolist(),
            'vertices-score': self.v_confidences.cpu().tolist(),
            'edges': self.edges.cpu().tolist(),
            'edges-weights': self.weights.cpu().tolist(),
            'height': self.frame_height,
            'width': self.frame_width,
        }

    def __repr__(self) -> str:
        return "WireframeGraph\n"+\
               "Vertices: {}\n".format(self.num_vertices)+\
               "Edges: {}\n".format(self.num_edges,) + \
               "Frame size (HxW): {}x{}".format(self.frame_height, self.frame_width)
def adjacent_matrix(self, n, edges, device):
    """Return a symmetric boolean adjacency matrix of size (n+1) x (n+1).

    ``edges`` holds one index pair per row; both ``[a, b]`` and ``[b, a]`` are
    set for every pair.  An ``edges`` tensor with no rows gives an all-False
    matrix.
    """
    size = n + 1
    adjacency = torch.zeros(size, size, dtype=torch.bool, device=device)
    if edges.size(0) <= 0:
        return adjacency

    heads = edges[:, 0]
    tails = edges[:, 1]
    adjacency[heads, tails] = 1
    adjacency[tails, heads] = 1
    return adjacency
np.random.permutation(lines_neg.cpu().numpy())[:self.num_static_neg_lines] + # lpos = lines[torch.randperm(lines.size(0),device=device)][:self.num_static_pos_lines] + # lneg = lines_neg[torch.randperm(lines_neg.size(0),device=device)][:self.num_static_neg_lines] + lpos = torch.from_numpy(lpos).to(device) + lneg = torch.from_numpy(lneg).to(device) + + lpre = torch.cat((lpos,lneg),dim=0) + _swap = (torch.rand(lpre.size(0))>0.5).to(device) + lpre[_swap] = lpre[_swap][:,[2,3,0,1]] + lpre_label = torch.cat( + [ + torch.ones(lpos.size(0),device=device), + torch.zeros(lneg.size(0),device=device) + ]) + + meta = { + 'junc': junctions, + 'Lpos': pos_mat, + 'Lneg': neg_mat, + 'lpre': lpre, + 'lpre_label': lpre_label, + 'lines': lines, + } + + + dismap = torch.sqrt(lmap[0]**2+lmap[1]**2)[None] + def _normalize(inp): + mag = torch.sqrt(inp[0]*inp[0]+inp[1]*inp[1]) + return inp/(mag+1e-6) + md_map = _normalize(lmap[:2]) + st_map = _normalize(lmap[2:4]) + ed_map = _normalize(lmap[4:]) + st_map = lmap[2:4] + ed_map = lmap[4:] + + md_ = md_map.reshape(2,-1).t() + st_ = st_map.reshape(2,-1).t() + ed_ = ed_map.reshape(2,-1).t() + Rt = torch.cat( + (torch.cat((md_[:,None,None,0],md_[:,None,None,1]),dim=2), + torch.cat((-md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1) + R = torch.cat( + (torch.cat((md_[:,None,None,0], -md_[:,None,None,1]),dim=2), + torch.cat((md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1) + + Rtst_ = torch.matmul(Rt, st_[:,:,None]).squeeze(-1).t() + Rted_ = torch.matmul(Rt, ed_[:,:,None]).squeeze(-1).t() + swap_mask = (Rtst_[1]<0)*(Rted_[1]>0) + pos_ = Rtst_.clone() + neg_ = Rted_.clone() + temp = pos_[:,swap_mask] + pos_[:,swap_mask] = neg_[:,swap_mask] + neg_[:,swap_mask] = temp + + pos_[0] = pos_[0].clamp(min=1e-9) + pos_[1] = pos_[1].clamp(min=1e-9) + neg_[0] = neg_[0].clamp(min=1e-9) + neg_[1] = neg_[1].clamp(max=-1e-9) + + mask = (dismap.view(-1)<=self.dis_th).float() + + pos_map = pos_.reshape(-1,height,width) + neg_map = 
def build_dpt(
    basemodel = "vitb_rn50_384",
    features=256,
    readout = "project",
    channels_last = False,
    use_bn = True,
    enable_attention_hooks = False,
    head_size = None,
    use_layer_scale = False,
    **kwargs):
    """Construct a ``DPTFieldModel`` from keyword options.

    Args:
        basemodel: backbone identifier, passed as ``backbone``.
        features, readout, channels_last, use_bn, enable_attention_hooks,
        use_layer_scale: forwarded to ``DPTFieldModel`` unchanged.
        head_size: per-head channel specification; ``None`` selects the
            default ``[[3], [1], [1], [2], [2]]``.
        **kwargs: accepted so that callers may pass a wider config dict;
            the extra entries are ignored.

    Returns:
        The constructed ``DPTFieldModel``.
    """
    if head_size is None:
        # Fresh list per call: a list literal in the signature would be one
        # shared object handed to every model built with the default.
        head_size = [[3], [1], [1], [2], [2]]

    model = DPTFieldModel(
        features=features,
        backbone=basemodel,
        readout=readout,
        channels_last=channels_last,
        use_bn=use_bn,
        enable_attention_hooks=enable_attention_hooks,
        head_size=head_size,
        use_layer_scale=use_layer_scale
    )

    return model
class BaseModel(torch.nn.Module):
    """``nn.Module`` with a helper to restore weights from a checkpoint file."""

    def load(self, path):
        """Load model weights from ``path`` onto the CPU.

        The file may hold either a plain state dict or a training checkpoint;
        a checkpoint is recognised by its ``"optimizer"`` key, in which case
        the weights are taken from its ``"model"`` entry.

        Args:
            path (str): file path
        """
        # NOTE: torch.load unpickles the file - only use trusted checkpoints.
        checkpoint = torch.load(path, map_location=torch.device("cpu"))
        state = checkpoint["model"] if "optimizer" in checkpoint else checkpoint
        self.load_state_dict(state)
enable_attention_hooks=enable_attention_hooks, + use_layer_scale=use_layer_scale, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch( + [256, 512, 1024, 2048], features, groups=groups, expand=expand + ) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + return scratch + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = 
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate``."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Store the resampling options.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
            align_corners (bool): forwarded to the interpolation call
        """
        super().__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resample ``x`` with the stored options.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        options = {
            "scale_factor": self.scale_factor,
            "mode": self.mode,
            "align_corners": self.align_corners,
        }
        return self.interp(x, **options)
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + ): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output diff --git a/scalelsd/ssl/backbones/dpt/midas_net.py b/scalelsd/ssl/backbones/dpt/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..34d6d7e77b464e7df45b7ab45174a7413d8fbc89 --- /dev/null +++ b/scalelsd/ssl/backbones/dpt/midas_net.py @@ -0,0 +1,77 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet_large(BaseModel): + """Network for monocular depth estimation.""" + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. 
+ features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_large, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder( + backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained + ) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/scalelsd/ssl/backbones/dpt/models.py b/scalelsd/ssl/backbones/dpt/models.py new file mode 100644 index 0000000000000000000000000000000000000000..f18d780cce763b4389ec4ac5869933ac90a886ac --- /dev/null +++ b/scalelsd/ssl/backbones/dpt/models.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) +from ..multi_task_head import MultitaskHead + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + enable_attention_hooks=False, + use_layer_scale=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + 
features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + enable_attention_hooks=enable_attention_hooks, + use_layer_scale=use_layer_scale, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + +class DPTFieldModel(DPT): + def __init__(self, path=None, non_negative=True, head_size=[[3],[1],[1],[2],[2]], **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + kwargs["use_bn"] = True + + num_class = sum(sum(head_size,[])) + head = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1), + # nn.BatchNorm2d(features//2), + nn.ReLU(True), + MultitaskHead(features//2, num_class, head_size=head_size), + ) + + super().__init__(head, **kwargs) + + self.stride = 2 + + def forward(self, x): + if x.shape[1] == 1: + x = torch.cat([x,x,x], dim=1) + + out = super().forward(x) + return out, None + diff --git a/scalelsd/ssl/backbones/dpt/transforms.py b/scalelsd/ssl/backbones/dpt/transforms.py new file mode 100644 index 
0000000000000000000000000000000000000000..399adbcdad096ae3fb8a190ecd3ec5483a897251 --- /dev/null +++ b/scalelsd/ssl/backbones/dpt/transforms.py @@ -0,0 +1,231 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height).""" + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. 
+ ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + 
f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std.""" + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input.""" + + def __init__(self): + pass + + def __call__(self, 
sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/scalelsd/ssl/backbones/dpt/vit.py b/scalelsd/ssl/backbones/dpt/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..208324a401c2af33a6c11be530d54c97d50d6458 --- /dev/null +++ b/scalelsd/ssl/backbones/dpt/vit.py @@ -0,0 +1,586 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +attention = {} + + +def get_attention(name): + def hook(module, input, output): + x = input[0] + B, N, C = x.shape + qkv = ( + module.qkv(x) + .reshape(B, N, 3, module.num_heads, C // module.num_heads) + .permute(2, 0, 3, 1, 4).contiguous() + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1).contiguous()) * module.scale + + attn = attn.softmax(dim=-1) # [:,:,1,1:] + attention[name] = attn + + return hook + + +def get_mean_attention_map(attn, token, shape): + attn = attn[:, :, token, 1:] + attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float() + attn = torch.nn.functional.interpolate( + attn, size=shape[2:], mode="bicubic", align_corners=False + ).squeeze(0) + + all_attn = torch.mean(attn, 0) + + return all_attn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + 
self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1).contiguous() + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : 
len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2).contiguous() + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif 
use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + enable_attention_hooks=False, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if enable_attention_hooks: + pretrained.model.blocks[hooks[0]].attn.register_forward_hook( + get_attention("attn_1") + ) + pretrained.model.blocks[hooks[1]].attn.register_forward_hook( + get_attention("attn_2") + ) + pretrained.model.blocks[hooks[2]].attn.register_forward_hook( + get_attention("attn_3") + ) + pretrained.model.blocks[hooks[3]].attn.register_forward_hook( + get_attention("attn_4") + ) + pretrained.attention = attention + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = 
nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, + enable_attention_hooks=False, + use_layer_scale=False, +): + pretrained = nn.Module() + + ### + if use_layer_scale: + from timm.models.vision_transformer import LayerScale + for i, block in enumerate (model.blocks) : + block.ls1 = LayerScale(vit_features) + block.ls2 = LayerScale(vit_features) + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + if enable_attention_hooks: + pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1")) + pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2")) + pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3")) + pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4")) + pretrained.attention = attention + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + 
in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, + use_readout="ignore", + hooks=None, + use_vit_only=False, + enable_attention_hooks=False, + use_layer_scale=False, +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + use_layer_scale=use_layer_scale, + ) + + +def _make_pretrained_vitl16_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + + +def _make_pretrained_vitb16_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + + +def _make_pretrained_deitb16_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + 
hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + + +def _make_pretrained_deitb16_distil_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + enable_attention_hooks=enable_attention_hooks, + ) diff --git a/scalelsd/ssl/backbones/multi_task_head.py b/scalelsd/ssl/backbones/multi_task_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f40a0f15b5cc72e3f140545b3482970b863734fd --- /dev/null +++ b/scalelsd/ssl/backbones/multi_task_head.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +class MultitaskHead(nn.Module): + def __init__(self, input_channels, num_class, head_size): + super(MultitaskHead, self).__init__() + + m = int(input_channels / 4) + heads = [] + for output_channels in sum(head_size, []): + heads.append( + nn.Sequential( + nn.Conv2d(input_channels, m, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(m, output_channels, kernel_size=1), + ) + ) + self.heads = nn.ModuleList(heads) + assert num_class == sum(sum(head_size, [])) + + def forward(self, x): + # import pdb;pdb.set_trace() + return torch.cat([head(x) for head in self.heads], dim=1) + + +class AngleDistanceHead(nn.Module): + def __init__(self, input_channels, num_class, head_size): + super(AngleDistanceHead, self).__init__() + + m = int(input_channels/4) + + heads = [] + for output_channels in sum(head_size, []): + if output_channels != 2: + heads.append( + nn.Sequential( + nn.Conv2d(input_channels, m, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(m, 
output_channels, kernel_size=1), + ) + ) + else: + heads.append( + nn.Sequential( + nn.Conv2d(input_channels, m, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + CosineSineLayer(m) + ) + ) + self.heads = nn.ModuleList(heads) + assert num_class == sum(sum(head_size, [])) + def forward(self, x): + return torch.cat([head(x) for head in self.heads], dim=1) \ No newline at end of file diff --git a/scalelsd/ssl/config/__init__.py b/scalelsd/ssl/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d19d721b1406422a9dc5e341c21d31c5984d9 --- /dev/null +++ b/scalelsd/ssl/config/__init__.py @@ -0,0 +1,2 @@ +from .project_config import Config +from .utils import * \ No newline at end of file diff --git a/scalelsd/ssl/config/dataset/hpatches_dataset.yaml b/scalelsd/ssl/config/dataset/hpatches_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..70933157ecc10ebbb797d838f17bd067fd885c8c --- /dev/null +++ b/scalelsd/ssl/config/dataset/hpatches_dataset.yaml @@ -0,0 +1,105 @@ +### General dataset parameters +dataset_name: "hpatches" +add_augmentation_to_all_splits: False +gray_scale: True +# Ground truth source ('official' or path to the exported h5 dataset.) 
+# gt_source_train: "" # Fill with your own export file +# gt_source_test: "" # Fill with your own export file +# Return type: (1) single (to train the detector only) +# or (2) paired_desc (to train the detector + descriptor) +return_type: "single" +random_seed: 0 + +### Descriptor training parameters +# Number of points extracted per line +max_num_samples: 10 +# Max number of training line points extracted in the whole image +max_pts: 1000 +# Min distance between two points on a line (in pixels) +min_dist_pts: 10 +# Small jittering of the sampled points during training +jittering: 0 + +alteration: "all" +max_side: 1200 + +### Data preprocessing configuration +preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + random_scaling: + enable: True + range: [0.7, 1.5] + photometric: + enable: true + primitives: ['random_brightness', 'random_contrast', + 'additive_speckle_noise', 'additive_gaussian_noise', + 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: {prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: true + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 + +## Homography adaptation configuration +homography_adaptation: + num_iter: 10 + valid_border_margin: 3 + min_counts: 3 + homographies: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + allow_artifacts: true + patch_ratio: 0.85 + +data: + name: hpatches + dataset_dir: HPatches_sequences + 
alteration: all + max_side: 1200 + batch_size: 1 + num_workers: 4 +model: + name: deeplsd + tiny: False + sharpen: True + line_neighborhood: 5 + loss_weights: + df: 1. + angle: 1. + detect_lines: True + multiscale: False + scale_factors: [1., 1.5] + line_detection_params: + grad_nfa: True + merge: False + optimize: False + use_vps: False + optimize_vps: False + filtering: True + grad_thresh: 3 diff --git a/scalelsd/ssl/config/dataset/nyu_dataset.yaml b/scalelsd/ssl/config/dataset/nyu_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ceae2b4c7e86f2d85cd770ffcc52d14baf3e1152 --- /dev/null +++ b/scalelsd/ssl/config/dataset/nyu_dataset.yaml @@ -0,0 +1,77 @@ +### General dataset parameters +dataset_name: "nyu" +add_augmentation_to_all_splits: False +gray_scale: True +# Ground truth source ('official' or path to the exported h5 dataset.) +# gt_source_train: "" # Fill with your own export file +# gt_source_test: "" # Fill with your own export file +# Return type: (1) single (to train the detector only) +# or (2) paired_desc (to train the detector + descriptor) +return_type: "single" +random_seed: 0 + +val_size: 49 + +### Descriptor training parameters +# Number of points extracted per line +max_num_samples: 10 +# Max number of training line points extracted in the whole image +max_pts: 1000 +# Min distance between two points on a line (in pixels) +min_dist_pts: 10 +# Small jittering of the sampled points during training +jittering: 0 + +### Data preprocessing configuration +preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + random_scaling: + enable: True + range: [0.7, 1.5] + photometric: + enable: true + primitives: ['random_brightness', 'random_contrast', + 'additive_speckle_noise', 'additive_gaussian_noise', + 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: 
{prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: true + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 + +## Homography adaptation configuration +homography_adaptation: + num_iter: 10 + valid_border_margin: 3 + min_counts: 3 + homographies: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + allow_artifacts: true + patch_ratio: 0.85 diff --git a/scalelsd/ssl/config/dataset/official_yorkurban_dataset.yaml b/scalelsd/ssl/config/dataset/official_yorkurban_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4cfe27c50ac1f471fadec2806ec62ec40103a1c --- /dev/null +++ b/scalelsd/ssl/config/dataset/official_yorkurban_dataset.yaml @@ -0,0 +1,75 @@ +### General dataset parameters +dataset_name: "official_yorkurban" +add_augmentation_to_all_splits: False +gray_scale: True +# Ground truth source ('official' or path to the exported h5 dataset.) 
+# gt_source_train: "" # Fill with your own export file +# gt_source_test: "" # Fill with your own export file +# Return type: (1) single (to train the detector only) +# or (2) paired_desc (to train the detector + descriptor) +return_type: "single" +random_seed: 0 + +### Descriptor training parameters +# Number of points extracted per line +max_num_samples: 10 +# Max number of training line points extracted in the whole image +max_pts: 1000 +# Min distance between two points on a line (in pixels) +min_dist_pts: 10 +# Small jittering of the sampled points during training +jittering: 0 + +### Data preprocessing configuration +preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + random_scaling: + enable: True + range: [0.7, 1.5] + photometric: + enable: true + primitives: ['random_brightness', 'random_contrast', + 'additive_speckle_noise', 'additive_gaussian_noise', + 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: {prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: true + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 + +## Homography adaptation configuration +homography_adaptation: + num_iter: 10 + valid_border_margin: 3 + min_counts: 3 + homographies: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + allow_artifacts: true + patch_ratio: 0.85 diff --git a/scalelsd/ssl/config/dataset/rdnim_dataset.yaml 
b/scalelsd/ssl/config/dataset/rdnim_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0524bb0ccfb0d4f130d045c0be9c03a5b8667c8c --- /dev/null +++ b/scalelsd/ssl/config/dataset/rdnim_dataset.yaml @@ -0,0 +1,77 @@ +### General dataset parameters +dataset_name: "rdnim" +add_augmentation_to_all_splits: False +gray_scale: True +# Ground truth source ('official' or path to the exported h5 dataset.) +# gt_source_train: "" # Fill with your own export file +# gt_source_test: "" # Fill with your own export file +# Return type: (1) single (to train the detector only) +# or (2) paired_desc (to train the detector + descriptor) +return_type: "single" +random_seed: 0 + +### Descriptor training parameters +# Number of points extracted per line +max_num_samples: 10 +# Max number of training line points extracted in the whole image +max_pts: 1000 +# Min distance between two points on a line (in pixels) +min_dist_pts: 10 +# Small jittering of the sampled points during training +jittering: 0 + +reference: "night" + +### Data preprocessing configuration +preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + random_scaling: + enable: True + range: [0.7, 1.5] + photometric: + enable: true + primitives: ['random_brightness', 'random_contrast', + 'additive_speckle_noise', 'additive_gaussian_noise', + 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: {prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: true + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 
+ +## Homography adaptation configuration +homography_adaptation: + num_iter: 10 + valid_border_margin: 3 + min_counts: 3 + homographies: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + allow_artifacts: true + patch_ratio: 0.85 diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset-1024.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset-1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd79820d57c60bb7356054ad91b51b7b2db57464 --- /dev/null +++ b/scalelsd/ssl/config/dataset/synthetic_dataset-1024.yaml @@ -0,0 +1,49 @@ +### General dataset parameters +dataset_name: "synthetic_shape" +primitives: "all" +add_augmentation_to_all_splits: True +test_augmentation_seed: 200 +# Shape generation configuration +generation: + # split_sizes: {'train': 20000, 'val': 2000, 'test': 400} + split_sizes: {'train': 2000, 'val': 2000, 'test': 400} + random_seed: 10 + image_size: [960, 1280] + min_len: 0.0985 + min_label_len: 0.099 + params: + generate_background: + min_kernel_size: 150 + max_kernel_size: 500 + min_rad_ratio: 0.02 + max_rad_ratio: 0.031 + draw_stripes: + transform_params: [0.1, 0.1] + draw_multiple_polygons: + kernel_boundaries: [50, 100] + +### Data preprocessing configuration. 
+preprocessing: + resize: [1024, 1024] + blur_size: 11 +augmentation: + photometric: + enable: True + primitives: 'all' + params: {} + random_order: True + homographic: + enable: True + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.8 + max_angle: 1.57 + allow_artifacts: true + translation_overflow: 0.05 + valid_border_margin: 0 diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset-2k.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset-2k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f760ae66077247c9e6c355decf0f277e425561b4 --- /dev/null +++ b/scalelsd/ssl/config/dataset/synthetic_dataset-2k.yaml @@ -0,0 +1,50 @@ +### General dataset parameters +dataset_name: "synthetic_shape" +primitives: "all" +add_augmentation_to_all_splits: True +test_augmentation_seed: 200 +alias: 2k +# Shape generation configuration +generation: + # split_sizes: {'train': 20000, 'val': 2000, 'test': 400} + split_sizes: {'train': 2000, 'val': 200, 'test': 400} + random_seed: 10 + image_size: [960, 1280] + min_len: 0.0985 + min_label_len: 0.099 + params: + generate_background: + min_kernel_size: 150 + max_kernel_size: 500 + min_rad_ratio: 0.02 + max_rad_ratio: 0.031 + draw_stripes: + transform_params: [0.1, 0.1] + draw_multiple_polygons: + kernel_boundaries: [50, 100] + +### Data preprocessing configuration. 
+preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + photometric: + enable: True + primitives: 'all' + params: {} + random_order: True + homographic: + enable: True + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.8 + max_angle: 1.57 + allow_artifacts: true + translation_overflow: 0.05 + valid_border_margin: 0 diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset-4k.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset-4k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7d1dd99cc778b6211d4d3251fb0f12d8e4bf9daf --- /dev/null +++ b/scalelsd/ssl/config/dataset/synthetic_dataset-4k.yaml @@ -0,0 +1,50 @@ +### General dataset parameters +dataset_name: "synthetic_shape" +primitives: "all" +add_augmentation_to_all_splits: True +test_augmentation_seed: 200 +alias: 4k +# Shape generation configuration +generation: + # split_sizes: {'train': 20000, 'val': 2000, 'test': 400} + split_sizes: {'train': 4000, 'val': 2000, 'test': 400} + random_seed: 10 + image_size: [960, 1280] + min_len: 0.0985 + min_label_len: 0.099 + params: + generate_background: + min_kernel_size: 150 + max_kernel_size: 500 + min_rad_ratio: 0.02 + max_rad_ratio: 0.031 + draw_stripes: + transform_params: [0.1, 0.1] + draw_multiple_polygons: + kernel_boundaries: [50, 100] + +### Data preprocessing configuration. 
+preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + photometric: + enable: True + primitives: 'all' + params: {} + random_order: True + homographic: + enable: True + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.8 + max_angle: 1.57 + allow_artifacts: true + translation_overflow: 0.05 + valid_border_margin: 0 diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset-large.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..844f44af29b335e0b4275d0feba792a77a522513 --- /dev/null +++ b/scalelsd/ssl/config/dataset/synthetic_dataset-large.yaml @@ -0,0 +1,50 @@ +### General dataset parameters +dataset_name: "synthetic_shape" +primitives: "all" +add_augmentation_to_all_splits: True +test_augmentation_seed: 200 +alias: "synthetic_shape_large" +# Shape generation configuration +generation: + split_sizes: {'train': 20000, 'val': 2000, 'test': 400} + # split_sizes: {'train': 2000, 'val': 2000, 'test': 400} + random_seed: 10 + image_size: [960, 1280] + min_len: 0.0985 + min_label_len: 0.099 + params: + generate_background: + min_kernel_size: 150 + max_kernel_size: 500 + min_rad_ratio: 0.02 + max_rad_ratio: 0.031 + draw_stripes: + transform_params: [0.1, 0.1] + draw_multiple_polygons: + kernel_boundaries: [50, 100] + +### Data preprocessing configuration. 
+preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + photometric: + enable: True + primitives: 'all' + params: {} + random_order: True + homographic: + enable: True + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.8 + max_angle: 1.57 + allow_artifacts: true + translation_overflow: 0.05 + valid_border_margin: 0 diff --git a/scalelsd/ssl/config/dataset/synthetic_dataset.yaml b/scalelsd/ssl/config/dataset/synthetic_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a91258f956a7ee68f3cd796005eb1d5c30c9ae83 --- /dev/null +++ b/scalelsd/ssl/config/dataset/synthetic_dataset.yaml @@ -0,0 +1,51 @@ +### General dataset parameters +dataset_name: "synthetic_shape" +primitives: "all" +add_augmentation_to_all_splits: True +test_augmentation_seed: 200 +# Shape generation configuration +generation: + # split_sizes: {'train': 20000, 'val': 2000, 'test': 400} + # split_sizes: {'train': 2000, 'val': 2000, 'test': 400} + split_sizes: {'train': 100, 'val': 100, 'test': 100} + random_seed: 10 + # image_size: [960, 1280] + image_size: [1024, 1024] + min_len: 0.0985 + min_label_len: 0.099 + params: + generate_background: + min_kernel_size: 150 + max_kernel_size: 500 + min_rad_ratio: 0.02 + max_rad_ratio: 0.031 + draw_stripes: + transform_params: [0.1, 0.1] + draw_multiple_polygons: + kernel_boundaries: [50, 100] + +### Data preprocessing configuration. 
+preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + photometric: + enable: True + primitives: 'all' + params: {} + random_order: True + homographic: + enable: True + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.8 + max_angle: 1.57 + allow_artifacts: true + translation_overflow: 0.05 + valid_border_margin: 0 diff --git a/scalelsd/ssl/config/dataset/wireframe_official_gt copy.yaml b/scalelsd/ssl/config/dataset/wireframe_official_gt copy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6a7a4194e23f3717bbc3b3a2bf707891f2e9abb --- /dev/null +++ b/scalelsd/ssl/config/dataset/wireframe_official_gt copy.yaml @@ -0,0 +1,86 @@ +dataset_name: "wireframe" +add_augmentation_to_all_splits: False +gray_scale: True +# return_type: "paired_desc" +random_seed: 0 +# Ground truth source (official or path to the exported h5 dataset.) +gt_source_train: "official" +gt_source_test: "official" +# Data preprocessing configuration. 
+preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + random_scaling: + enable: True + range: [0.7, 1.5] + photometric: + enable: true + primitives: ['random_brightness', 'random_contrast', + 'additive_speckle_noise', 'additive_gaussian_noise', + 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: {prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: true + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 +# The homography adaptation configuration +homography_adaptation: + num_iter: 100 + aggregation: 'sum' + mode: 'ver1' + valid_border_margin: 3 + min_counts: 30 + homographies: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + allow_artifacts: true + patch_ratio: 0.85 +# Evaluation related config +evaluation: + repeatability: + # Initial random seed used to sample homographic augmentation + seed: 200 + # Parameter used to sample illumination change evaluation set. + photometric: + enable: False + # Parameter used to sample viewpoint change evaluation set. 
+ homographic: + enable: True + num_samples: 2 + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 \ No newline at end of file diff --git a/scalelsd/ssl/config/dataset/wireframe_official_gt.yaml b/scalelsd/ssl/config/dataset/wireframe_official_gt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6a7a4194e23f3717bbc3b3a2bf707891f2e9abb --- /dev/null +++ b/scalelsd/ssl/config/dataset/wireframe_official_gt.yaml @@ -0,0 +1,86 @@ +dataset_name: "wireframe" +add_augmentation_to_all_splits: False +gray_scale: True +# return_type: "paired_desc" +random_seed: 0 +# Ground truth source (official or path to the exported h5 dataset.) +gt_source_train: "official" +gt_source_test: "official" +# Data preprocessing configuration. +preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + random_scaling: + enable: True + range: [0.7, 1.5] + photometric: + enable: true + primitives: ['random_brightness', 'random_contrast', + 'additive_speckle_noise', 'additive_gaussian_noise', + 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: {prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: true + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 +# The homography adaptation configuration +homography_adaptation: + num_iter: 100 + aggregation: 'sum' + mode: 'ver1' + 
valid_border_margin: 3 + min_counts: 30 + homographies: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + allow_artifacts: true + patch_ratio: 0.85 +# Evaluation related config +evaluation: + repeatability: + # Initial random seed used to sample homographic augmentation + seed: 200 + # Parameter used to sample illumination change evaluation set. + photometric: + enable: False + # Parameter used to sample viewpoint change evaluation set. + homographic: + enable: True + num_samples: 2 + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 \ No newline at end of file diff --git a/scalelsd/ssl/config/dataset/yorkurban_dataset.yaml b/scalelsd/ssl/config/dataset/yorkurban_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af75ec177e59492368486a761bd4f34427e15e20 --- /dev/null +++ b/scalelsd/ssl/config/dataset/yorkurban_dataset.yaml @@ -0,0 +1,99 @@ +### General dataset parameters +dataset_name: "yorkurban" +add_augmentation_to_all_splits: False +gray_scale: True +# Ground truth source ('official' or path to the exported h5 dataset.) 
+# gt_source_train: "" # Fill with your own export file +# gt_source_test: "" # Fill with your own export file +# Return type: (1) single (to train the detector only) +# or (2) paired_desc (to train the detector + descriptor) +return_type: "single" +random_seed: 0 + +### Descriptor training parameters +# Number of points extracted per line +max_num_samples: 10 +# Max number of training line points extracted in the whole image +max_pts: 1000 +# Min distance between two points on a line (in pixels) +min_dist_pts: 10 +# Small jittering of the sampled points during training +jittering: 0 + +### Data preprocessing configuration +preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + random_scaling: + enable: True + range: [0.7, 1.5] + photometric: + enable: true + primitives: ['random_brightness', 'random_contrast', + 'additive_speckle_noise', 'additive_gaussian_noise', + 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: {prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: true + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 + +## Homography adaptation configuration +homography_adaptation: + num_iter: 10 + valid_border_margin: 3 + min_counts: 3 + homographies: + translation: false + rotation: false + scaling: true + perspective: false + scaling_amplitude: -1 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + allow_artifacts: true + patch_ratio: 0.85 +# Evaluation related config +evaluation: + repeatability: + # Initial random seed used to sample 
class Config(object):
    """Dataset and experiment folders for the whole project.

    Every attribute is a plain class-level path string. The class body runs
    when the module is imported, so the environment variables below are read
    once at import time and several directories are created as a side effect.

    Environment variables:
        DATASET_ROOT: Root folder of the datasets. Defaults to ``data-ssl``
            located three directory levels above this file's directory.
        EXPORT_ROOT: Root folder of the exported predictions. Defaults to the
            same ``data-ssl`` folder as ``DATASET_ROOT``.
        EXP_PATH: Root folder of the experiments. Defaults to ``exp-ssl``
            located three directory levels above this file's directory.
    """
    #####################
    ## Dataset setting ##
    #####################
    # Default data root: <directory of this file>/../../../data-ssl, made absolute.
    default_dataroot = os.path.join(
        os.path.dirname(__file__),
        '..','..','..','data-ssl'
    )
    default_dataroot = os.path.abspath(default_dataroot)
    # Default experiment root: <directory of this file>/../../../exp-ssl, made absolute.
    default_exproot = os.path.join(
        os.path.dirname(__file__),
        '..','..','..','exp-ssl'
    )
    default_exproot = os.path.abspath(default_exproot)

    # Datasets folder; override with the DATASET_ROOT environment variable.
    DATASET_ROOT = os.getenv("DATASET_ROOT", default_dataroot)
    # Created at import time when it does not exist yet.
    if not os.path.exists(DATASET_ROOT):
        os.makedirs(DATASET_ROOT, exist_ok=True)

    # Synthetic shape dataset (data root and cache path are the same folder)
    synthetic_dataroot = os.path.join(DATASET_ROOT, "synthetic_shapes")
    synthetic_cache_path = os.path.join(DATASET_ROOT, "synthetic_shapes")
    if not os.path.exists(synthetic_dataroot):
        os.makedirs(synthetic_dataroot, exist_ok=True)

    # Export folder; override with the EXPORT_ROOT environment variable.
    # Note that the fallback is the default data root, not DATASET_ROOT.
    EXPORT_ROOT = os.getenv("EXPORT_ROOT", default_dataroot)

    # Exported predictions dataset (data root and cache path are the same folder)
    export_dataroot = os.path.join(EXPORT_ROOT, "export_datasets")
    export_cache_path = os.path.join(EXPORT_ROOT, "export_datasets")
    if not os.path.exists(export_dataroot):
        os.makedirs(export_dataroot, exist_ok=True)

    # The dataset folders below are only named here; they are not created.

    # York Urban dataset
    yorkurban_dataroot = os.path.join(DATASET_ROOT, "YorkUrbanDB")
    yorkurban_cache_path = os.path.join(DATASET_ROOT, "YorkUrbanDB")

    # Wireframe dataset
    wireframe_dataroot = os.path.join(DATASET_ROOT, "wireframe")
    wireframe_cache_path = os.path.join(DATASET_ROOT, "wireframe")

    # Holicity dataset
    holicity_dataroot = os.path.join(DATASET_ROOT, "Holicity")
    holicity_cache_path = os.path.join(DATASET_ROOT, "Holicity")

    # Official York Urban dataset
    official_yorkurban_dataroot = os.path.join(DATASET_ROOT, "off_YorkUrbanDB")
    official_yorkurban_cache_path = os.path.join(DATASET_ROOT, "off_YorkUrbanDB")

    # NYU_depth_v2
    # NOTE(review): the cache attribute is named `nyu_dataroot_cache_path`,
    # unlike the `<name>_cache_path` pattern used by the other datasets.
    nyu_dataroot = os.path.join(DATASET_ROOT, "NYU_depth_v2")
    nyu_dataroot_cache_path = os.path.join(DATASET_ROOT, "NYU_depth_v2")

    # RDNIM and HPatches datasets (no separate cache path is declared)
    rdnim_dataroot = os.path.join(DATASET_ROOT, "RDNIM")
    hpatches_dataroot = os.path.join(DATASET_ROOT, "HPatches_sequences")

    ########################
    ## Experiment Setting ##
    ########################
    # Experiments folder; override with the EXP_PATH environment variable.
    EXP_PATH = os.getenv("EXP_PATH", default_exproot)

    # Created at import time when it does not exist yet.
    if not os.path.exists(EXP_PATH):
        os.makedirs(EXP_PATH, exist_ok=True)
def update_config(path, model_cfg=None, dataset_cfg=None):
    """Merge the configs saved under a resume folder into the given configs.

    ``model_cfg.yaml`` and ``dataset_cfg.yaml`` are read from ``path`` and
    their entries are written into ``model_cfg`` / ``dataset_cfg`` in place, so
    a saved value replaces a passed-in value with the same key. When a merged
    config ends up different from what was stored (i.e. the caller supplied
    keys the file did not have), the file on disk is rewritten with the merged
    content.

    Args:
        path: Folder containing ``model_cfg.yaml`` and ``dataset_cfg.yaml``.
        model_cfg: Optional dict to merge the saved model config into; a new
            dict is used when ``None``.
        dataset_cfg: Optional dict to merge the saved dataset config into; a
            new dict is used when ``None``.

    Returns:
        The tuple ``(model_cfg, dataset_cfg)`` after merging.
    """
    if model_cfg is None:
        model_cfg = {}
    if dataset_cfg is None:
        dataset_cfg = {}

    targets = (
        ("model_cfg.yaml", model_cfg),
        ("dataset_cfg.yaml", dataset_cfg),
    )

    # Phase 1: read both saved configs and merge them before touching the disk.
    stored = []
    for filename, cfg in targets:
        with open(os.path.join(path, filename), "r") as handle:
            loaded = yaml.safe_load(handle)
            cfg.update(loaded)
        stored.append(loaded)

    # Phase 2: rewrite only the files whose merged content changed.
    for (filename, cfg), loaded in zip(targets, stored):
        if not cfg == loaded:
            with open(os.path.join(path, filename), "w") as handle:
                yaml.dump(cfg, handle)

    return model_cfg, dataset_cfg
class Hybrid_Dataset(torch.utils.data.Dataset):
    """Pairs every image of a folder with a homography-warped copy of itself.

    For each ``*.png`` / ``*.jpg`` image found directly in ``images_root`` a
    ``.npz`` file with the same stem is stored next to it. The archive holds
    the resized grayscale image (``ref_image``), its warped version
    (``target_image``) and the sampled homography (``homo_mat``). Archives that
    already exist are reused unless ``overwrite`` is set, which keeps the
    sampled homographies fixed between runs.

    ``npz_files[i]`` always belongs to ``files[i]``.

    Args:
        datacfg: Dataset configuration; only stored on ``self.conf``.
        images_root: Folder that contains the input images.
        overwrite: When true, regenerate every archive even if it exists.

    Raises:
        ValueError: If no image is found in ``images_root`` or an image cannot
            be decoded.
    """

    def __init__(self, datacfg=None, images_root=None, overwrite=False):
        self.conf = datacfg
        self.root = images_root
        self.size = (512, 512)
        self.overwrite = overwrite

        self.files = sorted(
            glob.glob(f'{images_root}/*.png') + glob.glob(f'{images_root}/*.jpg')
        )
        if not self.files:
            raise ValueError(f'Could not find any images in the path of {self.root}. Please check the input images root path.')

        # Build the archive list in image order so that index i of the dataset
        # always refers to image i, whether its archive was cached or is new.
        self.npz_files = []
        for file in tqdm(self.files):
            npz_file = Path(file).with_suffix('.npz')
            if self.overwrite or not npz_file.exists():
                self._export_pair(file, npz_file)
            self.npz_files.append(npz_file)

    def _export_pair(self, image_path, npz_path):
        """Sample one homography for ``image_path`` and save the pair to ``npz_path``."""
        image = cv2.imread(image_path, 0)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError(f'Could not read the image {image_path}.')
        image = cv2.resize(image, self.size)
        image = np.array(image, dtype=np.float32) / 255.0

        H = sample_homography(self.size, **homography_params)[0]
        warped_image = cv2.warpPerspective(image, H, self.size)
        warped_image = np.array(warped_image, dtype=np.float32)

        np.savez(npz_path, ref_image=image, target_image=warped_image, homo_mat=H)

    def get_dataset(self):
        """Return the list of ``.npz`` archive paths, aligned with the images."""
        return self.npz_files

    def get_images(self):
        """Return the sorted list of image paths."""
        return self.files

    def len_dataset(self):
        """Return the number of images in the dataset."""
        return len(self.files)

    def __len__(self):
        """Return the number of samples, as required by samplers and loaders."""
        return len(self.npz_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return its arrays as a ``{name: array}`` dict.

        The archive is read completely and closed again, so no file handle
        stays open after the call.
        """
        npz_file = self.npz_files[idx]
        with np.load(npz_file) as data:
            return {key: data[key] for key in data.files}
def get_dataset(mode="train", dataset_cfg=None, homoadp=False, **kwargs):
    """ Initialize different dataset based on a configuration.

    Args:
        mode: Split / mode string forwarded to the dataset constructor.
        dataset_cfg: Mapping with at least a ``"dataset_name"`` entry; it is
            forwarded to the selected dataset class.
        homoadp: Forwarded to ``ImageCollections`` (``"general"`` only).
        **kwargs: Extra keyword arguments forwarded to ``ImageCollections``
            (``"general"`` only).

    Returns:
        A ``(dataset, collate_fn)`` tuple.

    Raises:
        ValueError: If ``dataset_cfg`` is missing, the dataset name is
            unknown, or the module backing an optional dataset cannot be
            imported.
    """
    # Check dataset config is given
    if dataset_cfg is None:
        raise ValueError("[Error] The dataset config is required!")

    dataset_name = dataset_cfg["dataset_name"]

    # Synthetic dataset
    if dataset_name == "synthetic_shape":
        dataset = SyntheticShapes(mode, dataset_cfg)
        collate_fn = synthetic_collate_fn

    # Wireframe dataset
    elif dataset_name == "wireframe":
        dataset = WireframeDataset(mode, dataset_cfg)
        collate_fn = wireframe_collate_fn

    elif dataset_name == "yorkurban":
        dataset = YorkUrbanDataset(mode, dataset_cfg)
        collate_fn = yorkurban_collate_fn

    # Holicity dataset. Its module-level import is disabled in this file, so
    # it is imported lazily here and a clear error is raised if it is absent.
    elif dataset_name == "holicity":
        try:
            from .holicity_dataset import HolicityDataset, holicity_collate_fn
        except ImportError as err:
            raise ValueError(
                "[Error] The dataset 'holicity' is not supported: "
                "the holicity dataset module could not be imported.") from err
        dataset = HolicityDataset(mode, dataset_cfg)
        collate_fn = holicity_collate_fn

    # Dataset merging several datasets in one (same lazy-import handling).
    elif dataset_name == "merge":
        try:
            from .merge_dataset import MergeDataset
            from .holicity_dataset import holicity_collate_fn
        except ImportError as err:
            raise ValueError(
                "[Error] The dataset 'merge' is not supported: "
                "the merge/holicity dataset modules could not be imported."
            ) from err
        dataset = MergeDataset(mode, dataset_cfg)
        collate_fn = holicity_collate_fn

    elif dataset_name == "general":
        dataset = ImageCollections(mode, dataset_cfg, homoadp=homoadp, **kwargs)
        collate_fn = images_collate_fn

    ## for the official YorkUrbanDB
    elif dataset_name == "official_yorkurban":
        # The module-level import of this class is wrapped in a try/except
        # that swallows failures, so the name may not exist; import it here
        # to surface the real cause.
        try:
            from .official_yorkurban_dataset import YorkUrban
        except ImportError as err:
            raise ValueError(
                "[Error] The dataset 'official_yorkurban' is not supported: "
                "its module could not be imported.") from err
        dataset = YorkUrban(mode, dataset_cfg)
        collate_fn = torch_loader.default_collate

    ## for the NYU_depth_v2
    elif dataset_name == "nyu":
        dataset = NYU(mode, dataset_cfg)
        collate_fn = torch_loader.default_collate

    elif dataset_name == "rdnim":
        dataset = RDNIM(dataset_cfg)
        collate_fn = torch_loader.default_collate

    elif dataset_name == "hpatches":
        dataset = HPatches(mode, dataset_cfg)
        collate_fn = torch_loader.default_collate

    else:
        raise ValueError(
            "[Error] The dataset '%s' is not supported" % dataset_name)

    return dataset, collate_fn
class HPatches(torch.utils.data.Dataset):
    """HPatches sequences: reference/target image pairs with homographies.

    In ``'test'`` mode each sample pairs ``1.ppm`` of a sequence with one of
    ``2.ppm`` .. ``6.ppm`` and the matching ``H_1_<i>`` file. In ``'export'``
    mode each of the six images of a sequence is a sample on its own.

    Args:
        mode: ``'test'`` or ``'export'``.
        config: Mapping read for ``'alteration'`` (``'i'`` or ``'v'`` keeps
            only the sequences whose folder name starts with that letter;
            any other value keeps all) and ``'max_side'`` (images whose
            larger side exceeds it are downscaled).
    """

    def __init__(self, mode='test', config=None):
        assert mode in ['test', 'export']

        self.conf = config
        self.root_dir = Path(cfg.hpatches_dataroot)
        alteration = config['alteration']
        self.data = []
        for path in (x for x in self.root_dir.iterdir() if x.is_dir()):
            if alteration in ('i', 'v') and path.stem[0] != alteration:
                continue
            if mode == 'test':
                ref_path = Path(path, "1.ppm")
                for i in range(2, 7):
                    target_path = Path(path, str(i) + '.ppm')
                    self.data.append({
                        "ref_name": str(ref_path.parent.stem + "_" + ref_path.stem),
                        "ref_img_path": str(ref_path),
                        "target_name": str(target_path.parent.stem + "_" + target_path.stem),
                        "target_img_path": str(target_path),
                        "H": np.loadtxt(str(Path(path, "H_1_" + str(i)))),
                    })
            else:
                for i in range(1, 7):
                    ref_path = Path(path, str(i) + '.ppm')
                    self.data.append({
                        "ref_name": str(ref_path.parent.stem + "_" + ref_path.stem),
                        "ref_img_path": str(ref_path)})

    def get_dataset(self):
        return self

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as a grayscale image, failing loudly if unreadable."""
        img = cv2.imread(path, 0)
        if img is None:
            # cv2.imread returns None instead of raising on failure.
            raise ValueError(f'Could not read image: {path}')
        return img

    def __getitem__(self, idx):
        """Return the sample ``idx`` as a dict of tensors and metadata.

        Always contains ``image`` (1 x H x W float32 in [0, 1]),
        ``image_path`` and ``name``. Samples built in ``'test'`` mode also
        contain ``warped_image``, ``warped_image_path``, ``warped_name`` and
        ``H``.
        """
        entry = self.data[idx]
        img0_path = entry['ref_img_path']
        img0 = self._read_gray(img0_path)
        img_size = img0.shape

        needs_resize = max(img_size) > self.conf['max_side']
        if needs_resize:
            s = self.conf['max_side'] / max(img_size)
            h_s = int(img_size[0] * s)
            w_s = int(img_size[1] * s)
            img0 = cv2.resize(img0, (w_s, h_s), interpolation=cv2.INTER_AREA)

        # Normalize the image in [0, 1]
        img0 = torch.tensor((img0.astype(float) / 255.)[None],
                            dtype=torch.float32)
        outputs = {'image': img0, 'image_path': img0_path,
                   'name': entry['ref_name']}

        if 'target_name' in entry:
            img1_path = entry['target_img_path']
            img1 = self._read_gray(img1_path)
            H = entry['H']

            if needs_resize:
                # The target is resized to the size computed for the
                # reference image, and H is adapted with the reference shape
                # on both sides.
                img1 = cv2.resize(img1, (w_s, h_s),
                                  interpolation=cv2.INTER_AREA)
                H = self.adapt_homography_to_preprocessing(
                    H, img_size, img_size, (h_s, w_s))

            # Normalize the image in [0, 1]
            img1 = torch.tensor((img1.astype(float) / 255.)[None],
                                dtype=torch.float32)
            H = torch.tensor(H, dtype=torch.float32)

            outputs['warped_image'] = img1
            outputs['warped_image_path'] = img1_path
            outputs['warped_name'] = entry['target_name']
            outputs['H'] = H

        return outputs

    def __len__(self):
        return len(self.data)

    def adapt_homography_to_preprocessing(self, H, img_shape1, img_shape2,
                                          target_size):
        """Re-express ``H`` between two images after they were rescaled.

        For each image a single scale factor ``max(target_size / shape)`` and
        a centering offset ``(shape * scale - target_size) / 2`` are derived;
        the result is
        ``translation2 @ scaling2 @ H @ scaling1 @ translation1``.

        Args:
            H: 3x3 homography between the original images.
            img_shape1: (height, width) of the first original image.
            img_shape2: (height, width) of the second original image.
            target_size: (height, width) after preprocessing.

        Returns:
            The adapted 3x3 homography as a numpy array.
        """
        source_size1 = np.array(img_shape1, dtype=float)
        source_size2 = np.array(img_shape2, dtype=float)
        target_size = np.array(target_size)

        # Get the scaling factor in resize
        scale1 = np.amax(target_size / source_size1)
        scaling1 = np.diag([1. / scale1, 1. / scale1, 1.]).astype(float)
        scale2 = np.amax(target_size / source_size2)
        scaling2 = np.diag([scale2, scale2, 1.]).astype(float)

        # Get the translation params in crop
        pad_y1 = (source_size1[0] * scale1 - target_size[0]) / 2.
        pad_x1 = (source_size1[1] * scale1 - target_size[1]) / 2.
        translation1 = np.array([[1., 0., pad_x1],
                                 [0., 1., pad_y1],
                                 [0., 0., 1.]], dtype=float)
        pad_y2 = (source_size2[0] * scale2 - target_size[0]) / 2.
        pad_x2 = (source_size2[1] * scale2 - target_size[1]) / 2.
        translation2 = np.array([[1., 0., -pad_x2],
                                 [0., 1., -pad_y2],
                                 [0., 0., 1.]], dtype=float)

        return translation2 @ scaling2 @ H @ scaling1 @ translation1
def images_collate_fn(batch):
    """ Customized collate_fn for image-collection samples.

    Keys listed in ``batch_keys`` are stacked with the default collate
    function, keys listed in ``list_keys`` are returned as plain Python lists
    (one entry per sample), and every other key is dropped. The key set is
    taken from the first sample.
    """
    batch_keys = ("image", "junction_map", "valid_mask", "heatmap",
                  "heatmap_pos", "heatmap_neg", "homography",
                  "line_points", "line_indices")
    list_keys = ("junctions", "line_map", "line_map_pos",
                 "line_map_neg", "file_key", "fname", "image_origin", "uuid")

    outputs = {}
    for data_key in batch[0].keys():
        stack = data_key in batch_keys
        keep_list = data_key in list_keys
        if stack and keep_list:
            raise ValueError(
                "[Error] A key matches batch keys and list keys simultaneously.")
        if stack:
            outputs[data_key] = torch_loader.default_collate(
                [b[data_key] for b in batch])
        elif keep_list:
            outputs[data_key] = [b[data_key] for b in batch]

    return outputs


class ImageCollections(Dataset):
    """Generic image collection with optional exported wireframe labels.

    The label source is chosen from ``config['gt_source_train']``:

    * a name ending in ``.h5``: file names are read from the HDF5 file and
      labels are read per item from it;
    * a name ending in ``.jsons``: a directory of per-image JSON label files
      that is matched against the images found on disk;
    * anything else / missing: images only, no labels.

    Args:
        mode: Accepted for interface compatibility; not read by this class.
        config: Dataset configuration mapping. ``None`` selects
            ``get_default_config()``.
        homoadp: When true (``.jsons`` source only) the dataset lists images
            to be labelled instead of labelled images.
        homoadp_resume: With ``homoadp``, skip images that already have a
            JSON label file.
    """

    def __init__(self, mode, config, homoadp=False, homoadp_resume=False):
        super(ImageCollections, self).__init__()
        self.config = self.get_default_config() if config is None else config
        # Read from self.config so that config=None does not crash here.
        h5path = self.config.get('gt_source_train', None)
        self.json_list = None
        self.use_json = False
        self.homoadp = homoadp
        self.homoadp_resume = homoadp_resume

        if self.config['img_reg_exp'] == 'all':
            self.config['img_reg_exp'] = [
                f'sa_{i:06d}/images/*.jpg' for i in range(998)]

        # __getitem__ needs this for every label source, so it is set before
        # the source-specific branches.
        self.dataset_root = osp.join(cfg.DATASET_ROOT,
                                     self.config['dataset_root'][0])

        if h5path is not None and h5path.endswith('.h5'):
            self.h5path = osp.join(cfg.EXPORT_ROOT, 'export_datasets', h5path)
            with h5py.File(self.h5path, 'r') as f:
                names = [k.decode('UTF-8') for k in f['filenames']]
            self.filenames = [osp.join(cfg.EXPORT_ROOT, n) for n in names]
        elif h5path is not None and h5path.endswith('.jsons'):
            self._init_from_jsons(h5path)
        else:
            self.folder_regexp = self._image_patterns()
            self.filenames = sum(
                [glob.glob(exp) for exp in self.folder_regexp], [])
            self.h5path = None

        self.default_config = self.get_default_config()
        self.dataset_name = self.config['alias']
        self.size = self.config['preprocessing']['resize']

        print("Found %d images in %s" % (len(self), self.config['dataset_root']))

        self.num_pad = int(math.ceil(math.log10(len(self)))) + 1 if len(self) > 0 else 0

    def _image_patterns(self):
        """Return the absolute glob patterns of the configured images."""
        return [osp.join(cfg.DATASET_ROOT, self.config['dataset_root'][0], exp)
                for exp in self.config['img_reg_exp']]

    def _list_images_for_jsons(self, h5path):
        """List candidate image paths for the ``.jsons`` label source.

        When ``cfg.DATASET_ROOT`` starts with ``'pcache'`` the listing is
        cached in a ``.pcache`` text file next to ``h5path``; otherwise the
        configured patterns are globbed.
        """
        if not cfg.DATASET_ROOT.startswith('pcache'):
            self.folder_regexp = self._image_patterns()
            found = sum([glob.glob(exp) for exp in self.folder_regexp], [])
            return [Path(f) for f in found]

        cache_file = Path(h5path).with_suffix('.pcache')
        if osp.isfile(cache_file) and (self.homoadp_resume or not self.homoadp):
            with open(cache_file, 'r') as _f:
                return [x.rstrip('\n') for x in _f.readlines()]

        self.folder_regexp = []
        filenames = []
        print('Loading from pcache......')
        for exp in tqdm(self.config['img_reg_exp']):
            pattern = Path(osp.join(self.config['dataset_root'][0], exp))
            folder = osp.join(cfg.DATASET_ROOT, str(pattern.parent))
            suffix = pattern.suffix
            filenames.extend(
                osp.join(folder, osp.basename(name))
                for name in os.listdir(folder) if name.endswith(suffix))

        with open(cache_file, 'w') as _f:
            _f.writelines('\n'.join(filenames))
        return filenames

    def _init_from_jsons(self, h5path):
        """Set up file lists from a directory of per-image JSON labels."""
        self.use_json = True
        self.h5path = osp.join(cfg.EXPORT_ROOT, 'export_datasets', h5path)
        json_list = []
        for exp in tqdm(self.config['img_reg_exp']):
            json_pattern = Path(exp).with_suffix('.json')
            json_list.extend(
                glob.glob(osp.join(self.h5path, str(json_pattern))))

        filenames = self._list_images_for_jsons(h5path)

        filedict = {
            str(Path(osp.relpath(f, self.dataset_root)).with_suffix('')): f
            for f in filenames}
        # Keys must be relative to the directory that was globbed
        # (self.h5path); using the raw config value only matches when that
        # value is an absolute path.
        jsondict = {
            str(Path(osp.relpath(j, self.h5path)).with_suffix('')): j
            for j in json_list}
        self.filenames = []
        self.json_list = []

        if self.homoadp:
            for k in filedict.keys():
                if k in jsondict and self.homoadp_resume:
                    continue
                self.filenames.append(str(filedict[k]))
            self.h5path = None
            self.use_json = False
            print(f"Found {len(json_list)} json files from the folder")
            print(f"Total images are reduced from {len(filenames)} to {len(self.filenames)}")
        else:
            for k in filedict.keys():
                if k in jsondict:
                    self.filenames.append(str(filedict[k]))
                    self.json_list.append(str(jsondict[k]))

    def __len__(self):
        return len(self.filenames)

    def get_padded_filename(self, num_pad, idx):
        """Return ``idx`` as a decimal string left-padded with zeros."""
        file_len = len("%d" % (idx))
        filename = "0" * (num_pad - file_len) + "%d" % (idx)
        return filename

    def train_preprocessing(self, data, numpy=False):
        """ Train preprocessing for the dataset.

        Resizes the image (rescaling the junctions accordingly), applies the
        configured photometric and homographic transforms, and stores the
        results back into ``data``, which is also returned. ``numpy`` is
        accepted but not read.
        """
        image = data['image']
        junctions = data.get('junctions', None)
        image_size = image.shape[:2]
        if not (list(image_size) == self.config['preprocessing']['resize']):
            size_old = list(image.shape)[:2]

            image = cv2.resize(
                image, tuple(self.config['preprocessing']['resize'][::-1]),
                interpolation=cv2.INTER_LINEAR)

            scales = (image.shape[0] / size_old[0], image.shape[1] / size_old[1])

            if junctions is not None:
                junctions *= torch.tensor(scales).reshape(1, 2)

        if self.config['augmentation']['photometric']['enable']:
            photo_trans_list = self.get_photo_transform()
            ### Apply photometric transforms
            np.random.shuffle(photo_trans_list)
            image_transform = transforms.Compose(
                photo_trans_list + [photoaug.normalize_image()])
        else:
            image_transform = photoaug.normalize_image()

        image = image_transform(image)

        if self.config['augmentation']['homographic']['enable']:
            homo_trans = self.get_homo_transform()
            outputs = homo_trans(image, junctions, data['line_map'])
            junctions = outputs["junctions"]
            image = outputs["warped_image"]
            line_map = outputs['line_map']
            data['line_map'] = torch.tensor(line_map)
            data['valid_mask'] = outputs['valid_mask']

        data['image'] = torch.from_numpy(image)[None]
        if junctions is not None:
            # Without homographic augmentation the junctions are still the
            # torch tensor built in __getitem__, which torch.from_numpy
            # rejects; as_tensor accepts both tensors and numpy arrays.
            data['junctions'] = torch.as_tensor(junctions).float()

        return data

    def get_homo_transform(self):
        """ Get homographic transforms (according to the config). """
        if not self.config["augmentation"]["homographic"]["enable"]:
            raise ValueError(
                "[Error] Homographic augmentation is not enabled.")
        homo_config = self.config["augmentation"]["homographic"]["params"]

        # Parse the homographic transforms
        image_shape = self.config["preprocessing"]["resize"]

        # Compute the min_label_len from config
        try:
            min_label_tmp = self.config["generation"]["min_label_len"]
        except (KeyError, TypeError):
            min_label_tmp = None

        # float label len => fraction
        if isinstance(min_label_tmp, float):
            min_label_len = min_label_tmp * min(image_shape)
        # int label len => length in pixel
        elif isinstance(min_label_tmp, int):
            # "resize" is a list, so it cannot be divided directly; index 0
            # is used on both sizes.
            # NOTE(review): confirm index 0 is the intended axis.
            scale_ratio = (image_shape[0]
                           / self.config["generation"]["image_size"][0])
            min_label_len = min_label_tmp * scale_ratio
        # if none => no restriction
        else:
            min_label_len = 0

        # Initialize the transform
        homographic_trans = homoaug.homography_transform(
            image_shape, homo_config, 0, min_label_len)

        return homographic_trans

    def get_photo_transform(self):
        """ Get list of photometric transforms (according to the config). """
        # Get the photometric transform config
        photo_config = self.config["augmentation"]["photometric"]
        if not photo_config["enable"]:
            raise ValueError(
                "[Error] Photometric augmentation is not enabled.")

        # Parse photometric transforms
        trans_lst = self.parse_transforms(photo_config["primitives"],
                                          photoaug.available_augmentations)
        trans_config_lst = [photo_config["params"].get(p, {})
                            for p in trans_lst]

        # List of photometric augmentation
        return [getattr(photoaug, trans)(**conf)
                for (trans, conf) in zip(trans_lst, trans_config_lst)]

    def parse_transforms(self, names, all_transforms):
        """ Parse the transform.

        ``'all'`` selects every entry of ``all_transforms``; a single name is
        wrapped into a list. The selection must be a subset of
        ``all_transforms``.
        """
        trans = all_transforms if (names == 'all') \
            else (names if isinstance(names, list) else [names])
        assert set(trans) <= set(all_transforms)
        return trans

    def check_files(self):
        """Write the loadable file names to ``<gt_source_train stem>_filtered.pcache``."""
        h5path = self.config.get('gt_source_train', None)
        valid_filenames = []
        for filename in self.filenames:
            try:
                # Converting to an array forces the image to be decoded.
                with PIL.Image.open(filename) as img:
                    np.array(img)
                valid_filenames.append(filename)
            except IOError:
                print(f"Unable to load image from path: {filename}")

        new_pcache_path = Path(h5path).with_name(f"{Path(h5path).stem}_filtered.pcache")
        with open(new_pcache_path, 'w') as _f:
            for filename in valid_filenames:
                _f.write(f"{filename}\n")

    def check_health(self):
        """Try to parse every JSON label file and report the failures.

        Returns:
            Dict with ``'images'`` (always empty here), ``'jsons'`` (paths of
            unparsable label files) and ``'status'`` (True if none failed).
        """
        is_healthy = True
        image_fail_list = []
        json_fail_list = []
        for idx in tqdm(range(len(self))):
            if self.h5path is not None and self.json_list is not None:
                try:
                    with open(self.json_list[idx], 'r') as f:
                        json.load(f)
                except Exception:
                    is_healthy = False
                    print(f'The image {self.filenames[idx]} is broken.')
                    json_fail_list.append(self.json_list[idx])
        return {
            'images': image_fail_list,
            'jsons': json_fail_list,
            'status': is_healthy
        }

    @staticmethod
    def _graph_to_labels(junctions, edges):
        """Drop junctions unused by any edge and build the adjacency map.

        Args:
            junctions: (N, 2) float tensor.
            edges: long tensor of junction index pairs.

        Returns:
            ``(junctions, line_map)``: the kept junctions with their two
            columns swapped, and a symmetric (M, M) float32 matrix with 1
            where two kept junctions are connected.
        """
        junctions_valid = torch.zeros(len(junctions), dtype=torch.bool)
        junctions_valid[edges.unique()] = 1
        # Map old junction indices to their position among the kept ones.
        junctions_idx = -torch.ones(len(junctions), dtype=torch.long)
        junctions_idx[junctions_valid] = torch.arange(junctions_valid.sum())
        edges_remapped = junctions_idx[edges]
        junctions = junctions[junctions_valid]
        line_map = torch.zeros(junctions.shape[0], junctions.shape[0],
                               dtype=torch.float32)
        if len(edges_remapped) > 0:
            line_map[edges_remapped[:, 0], edges_remapped[:, 1]] = 1
            line_map[edges_remapped[:, 1], edges_remapped[:, 0]] = 1
        return junctions[:, [1, 0]], line_map

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        try:
            image_origin = np.array(PIL.Image.open(filename))
        except Exception:
            # Best-effort fallback kept from the original implementation: an
            # unreadable image is replaced by a fixed placeholder file.
            image_origin = np.array(PIL.Image.open('hawp/ssl/config/exports/sa1b/00030043_0.png'))  # deal with the failed case

        # NOTE(review): PIL normally yields RGB channel order while the BGR
        # conversion codes are used here - confirm this is intended before
        # changing it, since it affects the pixel values fed to the model.
        if self.config['gray_scale']:
            image = cv2.cvtColor(image_origin, cv2.COLOR_BGR2GRAY)
        else:
            image = cv2.cvtColor(image_origin, cv2.COLOR_BGR2RGB)

        data = {
            'fname': filename,
            'image': image,
            'image_origin': image_origin,
        }
        data['uuid'] = osp.relpath(filename, self.dataset_root)

        if self.h5path is not None and self.json_list is None:
            with h5py.File(self.h5path, 'r') as f:
                gt_key = self.get_padded_filename(self.num_pad, idx)
                exported_label = parse_h5_data(f[gt_key])
            junctions = torch.tensor(exported_label['junctions']).float()
            edges = torch.tensor(exported_label['edges']).long()
            data['junctions'], data['line_map'] = self._graph_to_labels(
                junctions, edges)
        elif self.h5path is not None and self.json_list is not None:
            with open(self.json_list[idx], 'r') as f:
                json_data = json.load(f)
            junctions = torch.tensor(json_data['junctions']).float()
            if junctions.shape[0] == 0:
                junctions = torch.zeros((1, 2)).float()
            edges = torch.tensor(json_data['edges']).long()
            data['junctions'], data['line_map'] = self._graph_to_labels(
                junctions, edges)
        else:
            data['valid_mask'] = torch.ones(self.size, dtype=torch.float32)[None]

        return self.train_preprocessing(data)

    def get_default_config(self):
        """Return the built-in fallback configuration."""
        return {
            "dataset_name": "images",
            "add_augmentation_to_all_splits": False,
            "preprocessing": {
                "resize": [512, 512],
                "blur_size": 11,
            },
            "augmentation": {
                "photometric": {
                    "enable": False
                },
                "homographic": {
                    "enable": False
                }
            }
        }
def unproject_vp_to_world(vp, K):
    """ Convert the VPs from homogenous format in the image plane
    to world direction.

    Each row of ``vp`` is multiplied by ``inv(K)``, its second component is
    negated, and the row is scaled to unit length.
    """
    proj_vp = (np.linalg.inv(K) @ vp.T).T
    proj_vp[:, 1] *= -1
    proj_vp /= np.linalg.norm(proj_vp, axis=1, keepdims=True)
    return proj_vp


class NYU(torch.utils.data.Dataset):
    """NYU images with vanishing-point and labelled-line CSV annotations.

    Args:
        mode: ``'val'`` or ``'test'`` select a subset (see the note in
            ``__init__``); any other value keeps all images.
        config: Optional mapping; stored as ``self.conf`` and read by
            ``get_data_loader`` for ``'<split>_batch_size'`` and
            ``'num_workers'``.
    """

    def __init__(self, mode='test', config=None):
        # get_data_loader reads self.conf, so it has to exist.
        self.conf = {} if config is None else config

        # Extract the image names
        num_imgs = 1449
        # NOTE(review): val_size is negative, so with the slicing below
        # 'val' keeps items [49:] and 'test' keeps items [:49]. Confirm this
        # is the intended split before changing it.
        val_size = -49

        self.root_dir = cfg.nyu_dataroot
        # Image files are numbered from 1, annotation files from 0.
        self.img_paths = [
            os.path.join(self.root_dir, 'images',
                         'nyu_rgb_' + str(i + 1).zfill(4) + '.png')
            for i in range(num_imgs)]
        self.vps_paths = [
            os.path.join(self.root_dir, 'vps',
                         'vps_' + str(i).zfill(4) + '.csv')
            for i in range(num_imgs)]
        self.lines_paths = [
            os.path.join(self.root_dir, 'labelled_lines',
                         'labelled_lines_' + str(i).zfill(4) + '.csv')
            for i in range(num_imgs)]
        self.img_names = [str(i).zfill(4) for i in range(num_imgs)]

        # Separate validation and test
        if mode == 'val':
            keep = slice(-val_size, None)
        elif mode == 'test':
            keep = slice(None, -val_size)
        else:
            keep = slice(None)
        self.img_paths = self.img_paths[keep]
        self.vps_paths = self.vps_paths[keep]
        self.lines_paths = self.lines_paths[keep]
        self.img_names = self.img_names[keep]

        # Load the intrinsics
        fx_rgb = 5.1885790117450188e+02
        fy_rgb = 5.1946961112127485e+02
        cx_rgb = 3.2558244941119034e+02
        cy_rgb = 2.5373616633400465e+02
        self.K = torch.tensor([[fx_rgb, 0, cx_rgb],
                               [0, fy_rgb, cy_rgb],
                               [0, 0, 1]])

    def get_dataset(self, split):
        return self

    @staticmethod
    def _read_points(csv_path):
        """Read columns 1 and 2 of a space-delimited CSV, skipping the header.

        Returns a list of ``[col1, col2, 1.]`` rows.
        """
        points = []
        with open(csv_path) as csv_file:
            reader = csv.reader(csv_file, delimiter=' ')
            for ri, row in enumerate(reader):
                if ri == 0:
                    continue
                points.append([float(row[1]), float(row[2]), 1.])
        return points

    def __getitem__(self, idx):
        """Return image (as read by OpenCV), annotations and intrinsics."""
        img_path = self.img_paths[idx]
        name = str(Path(img_path).stem)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None instead of raising on failure.
            raise ValueError(f'Could not read image: {img_path}')

        # Load the GT VPs
        vps = unproject_vp_to_world(
            np.array(self._read_points(self.vps_paths[idx])), self.K.numpy())
        lines = self._read_points(self.lines_paths[idx])

        # Convert to torch tensors
        vps = torch.tensor(vps, dtype=torch.float)
        lines = torch.tensor(lines, dtype=torch.float)

        return {'image': img,
                'image_path': img_path,
                'name': name,
                'gt_lines': lines,
                'vps': vps,
                'K': self.K
                }

    def __len__(self):
        return len(self.img_paths)

    # Overwrite the parent data loader to handle custom split
    def get_data_loader(self, split, shuffle=False):
        """Return a data loader for a given split.

        The loader never shuffles; ``shuffle`` is accepted but not used.
        """
        assert split in ['val', 'test', 'export']
        batch_size = self.conf.get(split + '_batch_size')
        num_workers = self.conf.get('num_workers', batch_size)
        return DataLoader(self.get_dataset(split), batch_size=batch_size,
                          shuffle=False, pin_memory=True,
                          num_workers=num_workers)
class YorkUrban(torch.utils.data.Dataset):
    """York Urban images with ground-truth lines and vanishing points.

    Args:
        mode: ``'train'``, ``'val'`` or ``'test'``. ``'test'`` uses the
            ``testSetIndex`` entries of the split file, the other two use
            ``trainingSetIndex``.
        config: Optional mapping; stored as ``self.conf`` and read by
            ``get_data_loader`` for ``'<split>_batch_size'`` and
            ``'num_workers'``.
    """

    def __init__(self, mode='train', config=None):
        assert mode in ['train', 'val', 'test']

        # get_data_loader reads self.conf, so it has to exist.
        self.conf = {} if config is None else config

        # Extract the image names. os.listdir returns entries in arbitrary
        # order, so sort them to make the index-based split below
        # deterministic.
        self.root_dir = cfg.official_yorkurban_dataroot
        self.img_names = sorted(
            name for name in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, name)))
        assert len(self.img_names) == 102  ## 102 categories in total

        # Separate validation and test
        split_file = os.path.join(self.root_dir,
                                  'ECCV_TrainingAndTestImageNumbers.mat')
        split_mat = scipy.io.loadmat(split_file)
        split_key = 'testSetIndex' if mode == 'test' else 'trainingSetIndex'
        valid_set = split_mat[split_key][:, 0] - 1
        self.img_names = np.array(self.img_names)[valid_set]
        assert len(self.img_names) == 51

        # Load the intrinsics
        K_file = os.path.join(self.root_dir, 'cameraParameters.mat')
        K_mat = scipy.io.loadmat(K_file)
        f = K_mat['focal'][0, 0] / K_mat['pixelSize'][0, 0]
        p_point = K_mat['pp'][0] - 1  # -1 to convert to 0-based conv
        self.K = torch.tensor([[f, 0, p_point[0]],
                               [0, f, p_point[1]],
                               [0, 0, 1]])

    def __len__(self):
        return len(self.img_names)

    def get_dataset(self, split=None):
        """Return the dataset itself (``split`` is not used)."""
        return self

    def __getitem__(self, idx):
        """Return image (as read by OpenCV), GT lines, VPs and intrinsics."""
        img_name = self.img_names[idx]
        folder = os.path.join(self.root_dir, img_name)
        img_path = os.path.join(folder, f'{img_name}.jpg')
        name = str(Path(img_path).stem)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None instead of raising on failure.
            raise ValueError(f'Could not read image: {img_path}')

        # Load the GT lines and VP association
        lines_mat = scipy.io.loadmat(
            os.path.join(folder, f'{img_name}LinesAndVP.mat'))
        lines = lines_mat['lines'].reshape(-1, 2, 2)[:, :, [1, 0]] - 1
        vp_association = lines_mat['vp_association'][:, 0] - 1

        # Load the VPs (non orthogonal ones)
        vp_file = os.path.join(folder,
                               f'{img_name}GroundTruthVP_CamParams.mat')
        vps = scipy.io.loadmat(vp_file)['vp'].T

        # Keep only the relevant VPs and renumber the associations 0..k-1.
        # np.unique is sorted, so every new label is <= the label it
        # replaces and the in-place relabelling cannot collide.
        unique_vps = np.unique(vp_association)
        vps = vps[unique_vps]
        for i, index in enumerate(unique_vps):
            vp_association[vp_association == index] = i

        # Convert to torch tensors
        lines = torch.tensor(lines.astype(float), dtype=torch.float)
        vps = torch.tensor(vps, dtype=torch.float)
        vp_association = torch.tensor(vp_association, dtype=torch.int)

        return {'image': img,
                'image_path': img_path,
                'name': name,
                'gt_lines': lines,
                'vps': vps,
                'vp_association': vp_association,
                'K': self.K
                }

    # Overwrite the parent data loader to handle custom collate_fn
    def get_data_loader(self, split, shuffle=False):
        """Return a data loader for a given split.

        The loader never shuffles; ``shuffle`` is accepted but not used.
        """
        assert split in ['val', 'test']
        batch_size = self.conf.get(split + '_batch_size')
        num_workers = self.conf.get('num_workers', batch_size)
        return DataLoader(self.get_dataset(split), batch_size=batch_size,
                          shuffle=False, pin_memory=True,
                          num_workers=num_workers)
def read_timestamps(text_file):
    """
    Read a text file containing the timestamps of images
    and return a dictionary matching the name of the image
    to its timestamp.

    Every space-delimited row is read as ``name date hour minute``. The
    result is a dict of parallel lists under the keys ``'name'``, ``'date'``,
    ``'hour'``, ``'minute'`` and ``'time'`` (``hour + minute / 60``).
    """
    timestamps = {'name': [], 'date': [], 'hour': [],
                  'minute': [], 'time': []}
    with open(text_file, 'r') as csvfile:
        reader = csv.reader(csvfile, delimiter=' ')
        for row in reader:
            hour = int(row[2])
            minute = int(row[3])
            timestamps['name'].append(row[0])
            timestamps['date'].append(row[1])
            timestamps['hour'].append(hour)
            timestamps['minute'].append(minute)
            timestamps['time'].append(hour + minute / 60.)
    return timestamps


class RDNIM(torch.utils.data.Dataset):
    """Rotated Day-Night Image Matching pairs with homographies.

    Args:
        conf: Mapping overriding ``default_conf``. ``'reference'`` selects the
            reference image ``<reference>.jpg`` of each sequence;
            ``get_data_loader`` additionally reads ``'<split>_batch_size'``
            and ``'num_workers'``.
    """

    default_conf = {
        'dataset_dir': 'RDNIM',
        'reference': 'day',
    }

    def __init__(self, conf):
        # Merge with the defaults and keep the result: get_data_loader reads
        # self.conf.
        self.conf = {**self.default_conf, **(conf or {})}
        self._root_dir = Path(cfg.rdnim_dataroot)
        ref = self.conf['reference']

        # Extract the timestamps
        timestamps = {
            f.stem: read_timestamps(str(f))
            for f in Path(self._root_dir, 'time_stamps').iterdir()}

        # Extract the reference images paths
        references = {
            seq.stem: str(Path(seq, ref + '.jpg'))
            for seq in Path(self._root_dir, 'references').iterdir()}

        # Extract the images paths and the homographies
        self._files = []
        for seq in Path(self._root_dir, 'images').iterdir():
            seq_id = seq.stem
            seq_stamps = timestamps[seq_id]
            for img in (x for x in seq.iterdir() if x.suffix == '.jpg'):
                timestamp = seq_stamps['time'][
                    seq_stamps['name'].index(img.name)]
                # The homography sits next to the image, as <image>.txt.
                H = np.loadtxt(str(img.with_suffix('.txt'))).astype(float)
                self._files.append({
                    'img': str(img),
                    'ref': str(references[seq_id]),
                    'H': H,
                    'timestamp': timestamp})

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as grayscale and return a 1 x H x W tensor in [0, 1]."""
        img = cv2.imread(path, 0)
        if img is None:
            # cv2.imread returns None instead of raising on failure.
            raise ValueError(f'Could not read image: {path}')
        return torch.tensor((img.astype(float) / 255.)[None],
                            dtype=torch.float)

    def __getitem__(self, item):
        """Return the reference image, the query image and their homography."""
        entry = self._files[item]
        img0_path = entry['ref']
        img1_path = entry['img']
        img0 = self._read_gray(img0_path)
        img1 = self._read_gray(img1_path)
        H = torch.tensor(entry['H'], dtype=torch.float)

        return {'image': img0, 'warped_image': img1, 'H': H,
                'timestamp': entry['timestamp'],
                'image_path': img0_path, 'warped_image_path': img1_path}

    def __len__(self):
        return len(self._files)

    def get_dataset(self, split):
        assert split in ['test']
        return self

    # Overwrite the parent data loader to handle custom collate_fn
    def get_data_loader(self, split, shuffle=False):
        """Return a data loader for a given split."""
        assert split in ['test']
        batch_size = self.conf.get(split + '_batch_size')
        num_workers = self.conf.get('num_workers', batch_size)
        return DataLoader(self, batch_size=batch_size,
                          shuffle=shuffle or split == 'train',
                          pin_memory=True, num_workers=num_workers)
def __init__(self, mode="train", config=None):
    """Set up the synthetic shapes dataset for one split.

    The exported dataset is located (or exported when missing) through
    construct_dataset(), and the resulting datapoint list is stored.

    Args:
        mode: Dataset split, one of "train", "val" or "test".
        config: Dataset configuration; the default configuration is used
            when None.

    Raises:
        ValueError: If `mode` is not a supported split.
    """
    super(SyntheticShapes, self).__init__()
    if mode not in ("train", "val", "test"):
        raise ValueError(
            "[Error] Supported dataset modes are 'train', 'val', and 'test'.")
    self.mode = mode
    self.config = self.get_default_config() if config is None else config

    # Names of the drawing functions that can be exported.
    self.available_primitives = [
        'draw_checkerboard_multiseg',
        'draw_lines',
        'draw_polygon',
        'draw_multiple_polygons',
        'draw_star',
        'draw_stripes_multiseg',
        'draw_cube',
        'gaussian_noise',
    ]

    # Names and location of the filename cache.
    self.dataset_name = self.get_dataset_name()
    self.cache_name = self.get_cache_name()
    self.cache_path = cfg.synthetic_cache_path

    # Load the datapoint list, exporting the dataset first if needed.
    print("===============================================")
    self.filename_dataset, self.datapoints = self.construct_dataset()
    self.print_dataset_info()
    self.dataset_length = len(self.datapoints)

    # Path of the h5 file that __getitem__ opens for every sample.
    self.dataset_path = os.path.join(cfg.synthetic_dataroot,
                                     self.dataset_name + ".h5")

    # Augmented evaluation splits get a fixed seed so they are repeatable.
    if (self.mode in ("val", "test")
            and self.config["add_augmentation_to_all_splits"]):
        seed = self.config.get("test_augmentation_seed", 200)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # For CuDNN
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    self.data_prefetch = {}
def get_dataset_name(self):
    """Get the dataset name from the dataset config / default config.

    The name is ``<dataset_name>_<mode>``, followed by ``-<alias>`` when the
    config carries a non-empty "alias". When the configured "dataset_name"
    is None, the name of the default configuration is used.
    """
    base_name = self.config["dataset_name"]
    if base_name is None:
        # default_config only exists once get_default_config() has run,
        # which is not the case when a custom config was supplied.
        defaults = getattr(self, "default_config", None)
        if defaults is None:
            defaults = self.get_default_config()
        base_name = defaults["dataset_name"]

    dataset_name = base_name + "_%s" % self.mode
    alias = self.config.get('alias', None)
    if alias:
        dataset_name = dataset_name + '-%s' % alias

    return dataset_name


def get_cache_name(self):
    """Get the cache file name: the dataset name plus ``_cache.pkl``."""
    return self.get_dataset_name() + "_cache.pkl"
def create_filename_dataset_cache(self, filename_dataset, datapoints):
    """Create the filename dataset cache for faster initialization.

    Args:
        filename_dataset: Mapping from primitive name to its datapoint keys.
        datapoints: Flat list of all datapoint keys.

    The two objects are pickled together into
    ``<cache_path>/<cache_name>``; the directory is created when missing.
    """
    # exist_ok avoids the check-then-create race when several processes
    # build the cache at the same time.
    os.makedirs(self.cache_path, exist_ok=True)

    cache_file_path = os.path.join(self.cache_path, self.cache_name)
    data = {
        "filename_dataset": filename_dataset,
        "datapoints": datapoints,
    }
    with open(cache_file_path, "wb") as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
""" + # Check if the primitive is valid or not + if primitive not in self.available_primitives: + raise ValueError( + "[Error]: %s is not a supported primitive" % primitive) + # Set the random seed + synthetic_util.set_random_state(np.random.RandomState( + self.config["generation"]["random_seed"])) + + # Generate shapes + print("\t Generating %s ..." % primitive) + # folder = f'/home/kezeran/code/hawpv4-dev/data-ssl/synthetic_shapes/{primitive}' + # os.makedirs(folder, exist_ok=True) + for idx in tqdm(range(split_size), ascii=True): + # Generate background image + image = synthetic_util.generate_background( + self.config['generation']['image_size'], + **self.config['generation']['params']['generate_background']) + image = np.zeros(self.config['generation']['image_size'], dtype=np.uint8) + + # Generate points + drawing_func = getattr(synthetic_util, primitive) + kwarg = self.config["generation"]["params"].get(primitive, {}) + + # Get min_len and min_label_len + min_len = self.config["generation"]["min_len"] + min_label_len = self.config["generation"]["min_label_len"] + + # Some only take min_label_len, and gaussian noises take nothing + if primitive in ["draw_lines", "draw_polygon", + "draw_multiple_polygons", "draw_star"]: + data = drawing_func(image, min_len=min_len, + min_label_len=min_label_len, **kwarg) + elif primitive in ["draw_checkerboard_multiseg", + "draw_stripes_multiseg", "draw_cube"]: + data = drawing_func(image, min_label_len=min_label_len, + **kwarg) + else: + data = drawing_func(image, **kwarg) + + # Convert the data + if data["points"] is not None: + points = np.flip(data["points"], axis=1).astype('float') + line_map = data["line_map"].astype(np.int32) + else: + points = np.zeros([0, 2]).astype('float') + line_map = np.zeros([0, 0]).astype(np.int32) + + # Post-processing + blur_size = self.config["preprocessing"]["blur_size"] + image = cv2.GaussianBlur(image, (blur_size, blur_size), 0) + + # Resize the image and the point location. 
+ points = (points + * np.array(self.config['preprocessing']['resize'], + 'float') + / np.array(self.config['generation']['image_size'], + 'float')) + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # Generate the line heatmap after post-processing + junctions = np.flip(np.round(points).astype(np.int32), axis=1) + heatmap = (synthetic_util.get_line_heatmap( + junctions, line_map, + size=image.shape) * 255.).astype(np.uint8) + + # Record the data in group + num_pad = math.ceil(math.log10(split_size)) + 1 + file_key_name = self.get_padded_filename(num_pad, idx) + file_group = group.create_group(file_key_name) + + # save_path = f'{folder}/{idx}.png' + # cv2.imwrite(save_path, image) + # Store data + file_group.create_dataset("points", data=points, + compression="gzip") + file_group.create_dataset("image", data=image, + compression="gzip") + file_group.create_dataset("line_map", data=line_map, + compression="gzip") + file_group.create_dataset("heatmap", data=heatmap, + compression="gzip") + + def get_default_config(self): + """ Get default configuration of the dataset. """ + # Initialize the default configuration + self.default_config = { + "dataset_name": "synthetic_shape", + "primitives": "all", + "add_augmentation_to_all_splits": False, + # Shape generation configuration + "generation": { + "split_sizes": {'train': 10000, 'val': 400, 'test': 500}, + "random_seed": 10, + "image_size": [960, 1280], + "min_len": 0.09, + "min_label_len": 0.1, + 'params': { + 'generate_background': { + 'min_kernel_size': 150, 'max_kernel_size': 500, + 'min_rad_ratio': 0.02, 'max_rad_ratio': 0.031}, + 'draw_stripes': {'transform_params': (0.1, 0.1)}, + 'draw_multiple_polygons': {'kernel_boundaries': (50, 100)} + }, + }, + # Date preprocessing configuration. 
@staticmethod
def get_padded_filename(num_pad, idx):
    """Return `idx` in decimal, left-filled with "0" up to `num_pad` characters.

    Numbers that already need `num_pad` characters or more are returned
    without padding.
    """
    return ("%d" % (idx)).rjust(num_pad, "0")
def get_data_from_signature(self, primitive_name, index, reader=None):
    """Get data given the primitive name and index ("draw_lines", 10).

    Args:
        primitive_name: Name of the primitive group.
        index: Position of the sample inside that group.
        reader: Optional, already opened h5 reader. When None, the exported
            dataset at ``self.dataset_path`` is opened for this call.

    Returns:
        Whatever get_data_from_datapoint() returns for the datapoint.
    """
    # Check the primitive name and index
    self._check_primitive_and_index(primitive_name, index)

    # Get the datapoint from filename dataset
    datapoint = self.filename_dataset[primitive_name][index]

    if reader is not None:
        return self.get_data_from_datapoint(datapoint, reader)

    # get_data_from_datapoint() refuses to run without a reader, so open
    # the exported file here, the same way __getitem__ does.
    with h5py.File(self.dataset_path, "r", swmr=True) as h5_reader:
        return self.get_data_from_datapoint(datapoint, h5_reader)
@staticmethod
def junc_to_junc_map(junctions, image_size):
    """Rasterise junction coordinates into a binary junction map.

    Args:
        junctions: (N, 2) array; column 0 indexes the first image axis and
            column 1 the second one.
        image_size: Sequence whose first two entries give the map size.

    Returns:
        int32 array of shape (image_size[0], image_size[1], 1) holding 1 at
        every (rounded, clipped) junction position and 0 elsewhere.
    """
    pixels = np.round(junctions).astype(np.int32)

    # Junctions falling outside the image are moved onto its border.
    for axis in (0, 1):
        pixels[:, axis] = np.clip(pixels[:, axis], 0., image_size[axis] - 1)

    junc_map = np.zeros([image_size[0], image_size[1]])
    junc_map[pixels[:, 0], pixels[:, 1]] = 1

    return junc_map[..., None].astype(np.int32)
""" + # Fetch corresponding entries + image = data["image"] + junctions = data["points"] + line_map = data["line_map"] + heatmap = data["heatmap"] + image_size = image.shape[:2] + + # Resize the image before the photometric and homographic transforms + # Check if we need to do the resizing + if not(list(image.shape) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape) + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + junctions = ( + junctions + * np.array(self.config['preprocessing']['resize'], float) + / np.array(size_old, float)) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), + axis=1) + heatmap = synthetic_util.get_line_heatmap(junctions_xy, line_map, + size=image.shape) + heatmap = (heatmap * 255.).astype(np.uint8) + + # Update image size + image_size = image.shape[:2] + + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + # Check if we need to apply augmentations + # In training mode => yes. 
def test_preprocessing(self, data):
    """Prepare one stored sample for evaluation, without augmentation.

    When the image shape differs from ``config["preprocessing"]["resize"]``
    the image is resized, the point coordinates are rescaled and the line
    heatmap is redrawn. The image then goes through
    ``photoaug.normalize_image()`` and every entry is converted with
    ``transforms.ToTensor()``.

    Args:
        data: Mapping with the "image", "points", "line_map" and "heatmap"
            entries of one datapoint.

    Returns:
        Dict with the keys "image", "junctions", "junction_map",
        "line_map", "heatmap" and "valid_mask".
    """
    image = data["image"]
    points = data["points"]
    line_map = data["line_map"]
    heatmap = data["heatmap"]
    image_size = image.shape[:2]

    target_size = self.config["preprocessing"]["resize"]
    if not (list(image.shape) == target_size):
        size_old = list(image.shape)

        # The configured size is reversed before it is given to cv2.resize.
        resized = cv2.resize(
            image, tuple(self.config['preprocessing']['resize'][::-1]),
            interpolation=cv2.INTER_LINEAR)
        image = np.array(resized, dtype=np.uint8)

        # Move the points to the resized image.
        points = (points
                  * np.array(self.config['preprocessing']['resize'], float)
                  / np.array(size_old, float))

        # Redraw the heatmap from the rescaled points (columns flipped for
        # the drawing helper).
        flipped = np.flip(np.round(points).astype(np.int32), axis=1)
        heatmap = synthetic_util.get_line_heatmap(flipped, line_map,
                                                  size=image.shape)
        heatmap = (heatmap * 255.).astype(np.uint8)

        image_size = image.shape[:2]

    # Image transform
    image = photoaug.normalize_image()(image)

    # Joint transform
    junction_map = self.junc_to_junc_map(points, image_size)
    to_tensor = transforms.ToTensor()

    return {
        "image": to_tensor(image),
        "junctions": to_tensor(points),
        "junction_map": to_tensor(junction_map).to(torch.int),
        "line_map": to_tensor(line_map),
        "heatmap": to_tensor(heatmap),
        "valid_mask": to_tensor(np.ones(image_size)).to(torch.int32),
    }
def _check_export_dataset(self):
    """Check if the exported dataset exists.

    The dataset is exported as the single file
    ``<cfg.synthetic_dataroot>/<dataset_name>.h5`` (see
    export_synthetic_shapes), so that file is what is looked for.

    Returns:
        True when the h5 file exists and is not empty, False otherwise.
    """
    dataset_path = os.path.join(cfg.synthetic_dataroot,
                                self.dataset_name + ".h5")
    return os.path.isfile(dataset_path) and os.path.getsize(dataset_path) > 0
# Random generator shared by the drawing helpers of this module; swap it
# with set_random_state() to control the generated shapes.
random_state = np.random.RandomState(None)


def set_random_state(state):
    """Replace the module-wide random generator with `state`."""
    global random_state
    random_state = state


def get_random_color(background_color):
    """ Draw a random grayscale value that keeps at least a small contrast
    with the background color. """
    candidate = random_state.randint(256)
    if abs(candidate - background_color) < 30:
        # Not enough contrast: shift by half of the gray range.
        return (candidate + 128) % 256
    return candidate


def get_different_color(previous_colors, min_dist=50, max_count=20):
    """ Draw a color that contrasts with the previous colors.
    Parameters:
      previous_colors: np.array of the previous colors
      min_dist: the difference between the new color and
                the previous colors must be at least min_dist
      max_count: maximal number of redraws; the last draw is returned even
                 if it is still too close
    """
    color = random_state.randint(256)
    redraws = 0
    while redraws < max_count and np.any(
            np.abs(previous_colors - color) < min_dist):
        redraws += 1
        color = random_state.randint(256)
    return color
def generate_background(size=(960, 1280), nb_blobs=100, min_rad_ratio=0.01,
                        max_rad_ratio=0.05, min_kernel_size=50,
                        max_kernel_size=300):
    """ Generate a customized background image.
    Parameters:
      size: size of the image
      nb_blobs: number of circles to draw
      min_rad_ratio: the radius of blobs is at least min_rad_ratio * max(size)
      max_rad_ratio: the radius of blobs is at most max_rad_ratio * max(size)
      min_kernel_size: minimal size of the blur kernel
      max_kernel_size: maximal size of the blur kernel
    """
    canvas = np.zeros(size, dtype=np.uint8)
    largest_side = max(size)

    # Start from thresholded random noise.
    cv.randu(canvas, 0, 255)
    threshold = random_state.randint(256)
    cv.threshold(canvas, threshold, 255, cv.THRESH_BINARY, canvas)
    background_color = int(np.mean(canvas))

    # Blob centers: first column drawn along size[1], second along size[0].
    xs = random_state.randint(0, size[1], size=(nb_blobs, 1))
    ys = random_state.randint(0, size[0], size=(nb_blobs, 1))
    centers = np.concatenate([xs, ys], axis=1)

    for center_x, center_y in centers:
        color = get_random_color(background_color)
        # NOTE(review): the radius comes from the global np.random stream,
        # not from random_state, so set_random_state() does not control it.
        radius = np.random.randint(int(largest_side * min_rad_ratio),
                                   int(largest_side * max_rad_ratio))
        cv.circle(canvas, (center_x, center_y), radius, color, -1)

    kernel_size = random_state.randint(min_kernel_size, max_kernel_size)
    cv.blur(canvas, (kernel_size, kernel_size), canvas)
    return canvas
def ccw(A, B, C, dim):
    """ Check, element-wise, if the points A, B, C are listed in
    counter-clockwise order. `dim` is 2 for (N, 2) inputs; any other value
    indexes the inputs as (N, 2, K). """
    if dim == 2:  # only 2 dimensions
        first = (slice(None), 0)
        second = (slice(None), 1)
    else:  # dim should be equal to 3
        first = (slice(None), 0, slice(None))
        second = (slice(None), 1, slice(None))

    lhs = (C[second] - A[second]) * (B[first] - A[first])
    rhs = (B[second] - A[second]) * (C[first] - A[first])
    return lhs > rhs


def intersect(A, B, C, D, dim):
    """ Return true if any line segment AB intersects the matching CD """
    cd_separates_ab = ccw(A, C, D, dim) != ccw(B, C, D, dim)
    ab_separates_cd = ccw(A, B, C, dim) != ccw(A, B, D, dim)
    return np.any(cd_separates_ab & ab_separates_cd)
def get_line_map(points: np.ndarray, segments: np.ndarray) -> np.ndarray:
    """ Build the symmetric point-to-point adjacency matrix of the segments.

    Parameters:
      points: (N, 2) array of junction coordinates
      segments: (M, 4) array, each row holding the two endpoints of a segment
    Returns an (N, N) array with 1 where two points are joined by a segment.
    """
    num_points = points.shape[0]
    adjacency = np.zeros([num_points, num_points])

    for segment in segments:
        # Rows of `points` matching each endpoint on both coordinates.
        start_rows = np.flatnonzero(
            (points == segment[:2]).sum(axis=1) == 2)
        end_rows = np.flatnonzero(
            (points == segment[2:]).sum(axis=1) == 2)

        # Mark the connection in both directions.
        adjacency[start_rows, end_rows] = 1
        adjacency[end_rows, start_rows] = 1

    return adjacency
""" + # Make sure that the thickness is 1 + if not isinstance(thickness, int): + thickness = int(thickness) + + # If the junction points are not int => round them and convert to int + if not junctions.dtype == np.int32: + junctions = (np.round(junctions)).astype(np.int32) + + # Initialize empty map + heat_map = np.zeros(size) + + if junctions.shape[0] > 0: # If empty, just return zero map + # Iterate through all the junctions + for idx in range(junctions.shape[0]): + # if no connectivity, just skip it + if line_map[idx, :].sum() == 0: + continue + # Plot the line segment + else: + # Iterate through all the connected junctions + for idx2 in np.where(line_map[idx, :] == 1)[0]: + point1 = junctions[idx, :] + point2 = junctions[idx2, :] + + # Draw line + cv.line(heat_map, tuple(point1), tuple(point2), 1., thickness) + + return heat_map + + +def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32): + """ Draw random lines and output the positions of the pair of junctions + and line associativities. 
def check_segment_len(segments, min_len=32):
    """ Tell if at least one segment is shorter than min_len
    (True means too short). Each row of `segments` holds the two
    endpoints of a segment. """
    delta = segments[:, :2] - segments[:, 2:]
    lengths = np.sqrt(np.sum(delta ** 2, axis=1))
    return bool(np.any(lengths < min_len))
record all the segments no matter we are going to label it or not. + segments_raw = np.zeros([0, 4]) + for idx in range(num_corners): + if idx == (num_corners - 1): + p1 = points[idx] + p2 = points[0] + else: + p1 = points[idx] + p2 = points[idx + 1] + + segment = np.concatenate((p1, p2), axis=0) + # Only record the segments longer than min_label_len + seg_len = np.sqrt(np.sum((p1 - p2) ** 2)) + if seg_len >= min_label_len: + segments = np.concatenate((segments, segment[None, ...]), axis=0) + segments_raw = np.concatenate((segments_raw, segment[None, ...]), + axis=0) + + # If not enough corner, just regenerate one + if (num_corners < 3) or check_segment_len(segments_raw, min_len): + return draw_polygon(img, max_sides, min_len, min_label_len) + + # Get junctions from segments + junctions_all = np.concatenate((segments[:, :2], segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + else: + junc_points = np.unique(junctions_all, axis=0) + + # Get the line map + line_map = get_line_map(junc_points, segments) + + corners = points.reshape((-1, 1, 2)) + col = get_random_color(int(np.mean(img))) + cv.fillPoly(img, [corners], col) + + return { + "points": junc_points, + "line_map": line_map + } + + +def overlap(center, rad, centers, rads): + """ Check that the circle with (center, rad) + doesn't overlap with the other circles. """ + flag = False + for i in range(len(rads)): + if np.linalg.norm(center - centers[i]) < rad + rads[i]: + flag = True + break + return flag + + +def angle_between_vectors(v1, v2): + """ Compute the angle (in rad) between the two vectors v1 and v2. """ + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + +def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32, + min_label_len=64, safe_margin=5, **extra): + """ Draw multiple polygons with a random number of corners + and return the junction points + line map. 
+ Parameters: + max_sides: maximal number of sides + 1 + nb_polygons: maximal number of polygons + """ + segments = np.empty((0, 4), dtype=np.int32) + label_segments = np.empty((0, 4), dtype=np.int32) + centers = [] + rads = [] + points = np.empty((0, 2), dtype=np.int32) + background_color = int(np.mean(img)) + + min_dim = min(img.shape[0], img.shape[1]) + # Convert length constrain to pixel if given float number + if isinstance(min_len, float) and min_len <= 1.: + min_len = int(min_dim * min_len) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + if isinstance(safe_margin, float) and safe_margin <= 1.: + safe_margin = int(min_dim * safe_margin) + + # Sequentially generate polygons + for i in range(nb_polygons): + num_corners = random_state.randint(3, max_sides) + min_dim = min(img.shape[0], img.shape[1]) + + # Also add the real radius + rad = max(random_state.rand() * min_dim / 2, min_dim / 9) + rad_real = rad - safe_margin + + # Center of a circle + x = random_state.randint(rad, img.shape[1] - rad) + y = random_state.randint(rad, img.shape[0] - rad) + + # Sample num_corners points inside the circle + slices = np.linspace(0, 2 * math.pi, num_corners + 1) + angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i]) + for i in range(num_corners)] + + # Sample outer points and inner points + new_points = [] + new_points_real = [] + for a in angles: + x_offset = max(random_state.rand(), 0.4) + y_offset = max(random_state.rand(), 0.4) + new_points.append([int(x + x_offset * rad * math.cos(a)), + int(y + y_offset * rad * math.sin(a))]) + new_points_real.append( + [int(x + x_offset * rad_real * math.cos(a)), + int(y + y_offset * rad_real * math.sin(a))]) + new_points = np.array(new_points) + new_points_real = np.array(new_points_real) + + # Filter the points that are too close or that have an angle too flat + norms = [np.linalg.norm(new_points[(i-1) % num_corners, :] + - new_points[i, :]) + for i in 
range(num_corners)] + mask = np.array(norms) > 0.01 + new_points = new_points[mask, :] + new_points_real = new_points_real[mask, :] + + num_corners = new_points.shape[0] + corner_angles = [ + angle_between_vectors(new_points[(i-1) % num_corners, :] - + new_points[i, :], + new_points[(i+1) % num_corners, :] - + new_points[i, :]) + for i in range(num_corners)] + mask = np.array(corner_angles) < (2 * math.pi / 3) + new_points = new_points[mask, :] + new_points_real = new_points_real[mask, :] + num_corners = new_points.shape[0] + + # Not enough corners + if num_corners < 3: + continue + + # Segments for checking overlap (outer circle) + new_segments = np.zeros((1, 4, num_corners)) + new_segments[:, 0, :] = [new_points[i][0] for i in range(num_corners)] + new_segments[:, 1, :] = [new_points[i][1] for i in range(num_corners)] + new_segments[:, 2, :] = [new_points[(i+1) % num_corners][0] + for i in range(num_corners)] + new_segments[:, 3, :] = [new_points[(i+1) % num_corners][1] + for i in range(num_corners)] + + # Segments to record (inner circle) + new_segments_real = np.zeros((1, 4, num_corners)) + new_segments_real[:, 0, :] = [new_points_real[i][0] + for i in range(num_corners)] + new_segments_real[:, 1, :] = [new_points_real[i][1] + for i in range(num_corners)] + new_segments_real[:, 2, :] = [ + new_points_real[(i + 1) % num_corners][0] + for i in range(num_corners)] + new_segments_real[:, 3, :] = [ + new_points_real[(i + 1) % num_corners][1] + for i in range(num_corners)] + + # Check that the polygon will not overlap with pre-existing shapes + if intersect(segments[:, 0:2, None], segments[:, 2:4, None], + new_segments[:, 0:2, :], new_segments[:, 2:4, :], + 3) or overlap(np.array([x, y]), rad, centers, rads): + continue + + # Check that the the edges of the polygon is not too short + if check_segment_len(new_segments_real, min_len): + continue + + # If the polygon is valid, append it to the polygon set + centers.append(np.array([x, y])) + rads.append(rad) + 
new_segments = np.reshape(np.swapaxes(new_segments, 0, 2), (-1, 4)) + segments = np.concatenate([segments, new_segments], axis=0) + + # Only record the segments longer than min_label_len + new_segments_real = np.reshape(np.swapaxes(new_segments_real, 0, 2), + (-1, 4)) + points1 = new_segments_real[:, :2] + points2 = new_segments_real[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + new_label_segment = new_segments_real[seg_len >= min_label_len, :] + label_segments = np.concatenate([label_segments, new_label_segment], + axis=0) + + # Color the polygon with a custom background + corners = new_points_real.reshape((-1, 1, 2)) + mask = np.zeros(img.shape, np.uint8) + custom_background = generate_custom_background( + img.shape, background_color, **extra) + + cv.fillPoly(mask, [corners], 255) + locs = np.where(mask != 0) + img[locs[0], locs[1]] = custom_background[locs[0], locs[1]] + points = np.concatenate([points, new_points], axis=0) + + # Get all junctions from label segments + junctions_all = np.concatenate( + (label_segments[:, :2], label_segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + else: + junc_points = np.unique(junctions_all, axis=0) + + # Generate line map from points and segments + line_map = get_line_map(junc_points, label_segments) + + return { + "points": junc_points, + "line_map": line_map + } + + +def draw_ellipses(img, nb_ellipses=20): + """ Draw several ellipses. 
+ Parameters: + nb_ellipses: maximal number of ellipses + """ + centers = np.empty((0, 2), dtype=np.int32) + rads = np.empty((0, 1), dtype=np.int32) + min_dim = min(img.shape[0], img.shape[1]) / 4 + background_color = int(np.mean(img)) + for i in range(nb_ellipses): + ax = int(max(random_state.rand() * min_dim, min_dim / 5)) + ay = int(max(random_state.rand() * min_dim, min_dim / 5)) + max_rad = max(ax, ay) + x = random_state.randint(max_rad, img.shape[1] - max_rad) # center + y = random_state.randint(max_rad, img.shape[0] - max_rad) + new_center = np.array([[x, y]]) + + # Check that the ellipsis will not overlap with pre-existing shapes + diff = centers - new_center + if np.any(max_rad > (np.sqrt(np.sum(diff * diff, axis=1)) - rads)): + continue + centers = np.concatenate([centers, new_center], axis=0) + rads = np.concatenate([rads, np.array([[max_rad]])], axis=0) + + col = get_random_color(background_color) + angle = random_state.rand() * 90 + cv.ellipse(img, (x, y), (ax, ay), angle, 0, 360, col, -1) + return np.empty((0, 2), dtype=np.int32) + + +def draw_star(img, nb_branches=6, min_len=32, min_label_len=64): + """ Draw a star and return the junction points + line map. 
+ Parameters: + nb_branches: number of branches of the star + """ + num_branches = random_state.randint(3, nb_branches) + min_dim = min(img.shape[0], img.shape[1]) + # Convert length constrain to pixel if given float number + if isinstance(min_len, float) and min_len <= 1.: + min_len = int(min_dim * min_len) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + + thickness = random_state.randint(min_dim * 0.01, min_dim * 0.025) + rad = max(random_state.rand() * min_dim / 2, min_dim / 5) + x = random_state.randint(rad, img.shape[1] - rad) + y = random_state.randint(rad, img.shape[0] - rad) + # Sample num_branches points inside the circle + slices = np.linspace(0, 2 * math.pi, num_branches + 1) + angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i]) + for i in range(num_branches)] + points = np.array( + [[int(x + max(random_state.rand(), 0.3) * rad * math.cos(a)), + int(y + max(random_state.rand(), 0.3) * rad * math.sin(a))] + for a in angles]) + points = np.concatenate(([[x, y]], points), axis=0) + + # Generate segments and check the length + segments = np.array([[x, y, _[0], _[1]] for _ in points[1:, :]]) + if check_segment_len(segments, min_len): + return draw_star(img, nb_branches, min_len, min_label_len) + + # Only record the segments longer than min_label_len + points1 = segments[:, :2] + points2 = segments[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + label_segments = segments[seg_len >= min_label_len, :] + + # Get all junctions from label segments + junctions_all = np.concatenate( + (label_segments[:, :2], label_segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + # Get all unique junction points + else: + junc_points = np.unique(junctions_all, axis=0) + # Generate line map from points and segments + line_map = get_line_map(junc_points, label_segments) + + background_color = int(np.mean(img)) + for i in range(1, 
num_branches + 1): + col = get_random_color(background_color) + cv.line(img, (points[0][0], points[0][1]), + (points[i][0], points[i][1]), + col, thickness) + return { + "points": junc_points, + "line_map": line_map + } + + +def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, + transform_params=(0.05, 0.15), + min_label_len=64, seed=None): + """ Draw a checkerboard and output the junctions + line segments + Parameters: + max_rows: maximal number of rows + 1 + max_cols: maximal number of cols + 1 + transform_params: set the range of the parameters of the transformations + """ + if seed is None: + global random_state + else: + random_state = np.random.RandomState(seed) + + background_color = int(np.mean(img)) + + min_dim = min(img.shape) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + # Create the grid + rows = random_state.randint(3, max_rows) # number of rows + cols = random_state.randint(3, max_cols) # number of cols + s = min((img.shape[1] - 1) // cols, (img.shape[0] - 1) // rows) + x_coord = np.tile(range(cols + 1), + rows + 1).reshape(((rows + 1) * (cols + 1), 1)) + y_coord = np.repeat(range(rows + 1), + cols + 1).reshape(((rows + 1) * (cols + 1), 1)) + + # points are the grid coordinates + points = s * np.concatenate([x_coord, y_coord], axis=1) + + # Warp the grid using an affine transformation and an homography + alpha_affine = np.max(img.shape) * ( + transform_params[0] + random_state.rand() * transform_params[1]) + center_square = np.float32(img.shape) // 2 + min_dim = min(img.shape) + square_size = min_dim // 3 + pts1 = np.float32([center_square + square_size, + [center_square[0] + square_size, + center_square[1] - square_size], + center_square - square_size, + [center_square[0] - square_size, + center_square[1] + square_size]]) + pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, + size=pts1.shape).astype(np.float32) + affine_transform = cv.getAffineTransform(pts1[:3], 
pts2[:3]) + pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2, + size=pts1.shape).astype(np.float32) + perspective_transform = cv.getPerspectiveTransform(pts1, pts2) + + # Apply the affine transformation + points = np.transpose(np.concatenate( + (points, np.ones(((rows + 1) * (cols + 1), 1))), axis=1)) + warped_points = np.transpose(np.dot(affine_transform, points)) + + # Apply the homography + warped_col0 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[0, :2]), axis=1), + perspective_transform[0, 2]) + warped_col1 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[1, :2]), axis=1), + perspective_transform[1, 2]) + warped_col2 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[2, :2]), axis=1), + perspective_transform[2, 2]) + warped_col0 = np.divide(warped_col0, warped_col2) + warped_col1 = np.divide(warped_col1, warped_col2) + warped_points = np.concatenate( + [warped_col0[:, None], warped_col1[:, None]], axis=1) + warped_points_float = warped_points.copy() + warped_points = warped_points.astype(int) + + # Fill the rectangles + colors = np.zeros((rows * cols,), np.int32) + + label_segments = [] + for i in range(rows): + for j in range(cols): + # Get a color that contrast with the neighboring cells + if i == 0 and j == 0: + col = get_random_color(background_color) + else: + neighboring_colors = [] + if i != 0: + neighboring_colors.append(colors[(i - 1) * cols + j]) + if j != 0: + neighboring_colors.append(colors[i * cols + j - 1]) + col = get_different_color(np.array(neighboring_colors)) + colors[i * cols + j] = col + # Fill the cell + cv.fillConvexPoly(img, np.array( + [(warped_points[i * (cols + 1) + j, 0], + warped_points[i * (cols + 1) + j, 1]), + (warped_points[i * (cols + 1) + j + 1, 0], + warped_points[i * (cols + 1) + j + 1, 1]), + (warped_points[(i + 1) * (cols + 1) + j + 1, 0], + warped_points[(i + 1) * (cols + 1) + j + 1, 1]), + (warped_points[(i + 1) * (cols + 1) + j, 0], + 
warped_points[(i + 1) * (cols + 1) + j, 1])]), col) + line1 = np.concatenate([warped_points[i * (cols + 1) + j],warped_points[i * (cols + 1) + j+1]]) + line2 = np.concatenate([warped_points[i * (cols + 1) + j+1],warped_points[(i + 1) * (cols + 1) + j + 1]]) + line3 = np.concatenate([warped_points[(i + 1) * (cols + 1) + j + 1],warped_points[(i + 1) * (cols + 1) + j]]) + line4 = np.concatenate([warped_points[(i + 1) * (cols + 1) + j],warped_points[i * (cols + 1) + j]]) + lines = np.stack([line1,line2,line3,line4]) + label_segments.append(lines) + + label_segments = np.concatenate(label_segments) + """ + label_segments = np.empty([0, 4], dtype=np.int32) + # Iterate through rows + for row_idx in range(rows + 1): + # Include all the combination of the junctions + # Iterate through all the combination of junction index in that row + multi_seg_lst = [ + np.array([warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1]])[None, ...] + for (id1, id2) in combinations(range( + row_idx * (cols + 1), (row_idx + 1) * (cols + 1), 1), 2)] + multi_seg = np.concatenate(multi_seg_lst, axis=0) + label_segments = np.concatenate((label_segments, multi_seg), axis=0) + + # Iterate through columns + for col_idx in range(cols + 1): # for 5 columns, we will have 5 + 1 edges + # Include all the combination of the junctions + # Iterate throuhg all the combination of junction index in that column + multi_seg_lst = [ + np.array([warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1]])[None, ...] 
+ for (id1, id2) in combinations(range( + col_idx, col_idx + ((rows + 1) * (cols + 1)), cols + 1), 2)] + multi_seg = np.concatenate(multi_seg_lst, axis=0) + label_segments = np.concatenate((label_segments, multi_seg), axis=0) + """ + label_segments_filtered = np.zeros([0, 4]) + # Define image boundary polygon (in x y manner) + image_poly = shapely.geometry.Polygon( + [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1], + [0, img.shape[0] - 1]]) + for idx in range(label_segments.shape[0]): + # Get the line segment + seg_raw = label_segments[idx, :] + seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]]) + + # The line segment is just inside the image. + if seg.intersection(image_poly) == seg: + label_segments_filtered = np.concatenate( + (label_segments_filtered, seg_raw[None, ...]), axis=0) + + # Intersect with the image. + elif seg.intersects(image_poly): + # Check intersection + try: + p = np.array(seg.intersection( + image_poly).coords).reshape([-1, 4]) + # If intersect with eact one point + except: + continue + segment = p + label_segments_filtered = np.concatenate( + (label_segments_filtered, segment), axis=0) + + else: + continue + + label_segments = np.round(label_segments_filtered).astype(np.int32) + + # Only record the segments longer than min_label_len + points1 = label_segments[:, :2] + points2 = label_segments[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + label_segments = label_segments[seg_len >= min_label_len, :] + + # Get all junctions from label segments + junc_points, line_map = get_unique_junctions(label_segments, + min_label_len) + + # Draw lines on the boundaries of the board at random + nb_rows = random_state.randint(2, rows + 2) + nb_cols = random_state.randint(2, cols + 2) + thickness = random_state.randint(min_dim * 0.01, min_dim * 0.015) + for _ in range(nb_rows): + row_idx = random_state.randint(rows + 1) + col_idx1 = random_state.randint(cols + 1) + col_idx2 = random_state.randint(cols + 
1) + col = get_random_color(background_color) + cv.line(img, (warped_points[row_idx * (cols + 1) + col_idx1, 0], + warped_points[row_idx * (cols + 1) + col_idx1, 1]), + (warped_points[row_idx * (cols + 1) + col_idx2, 0], + warped_points[row_idx * (cols + 1) + col_idx2, 1]), + col, thickness) + for _ in range(nb_cols): + col_idx = random_state.randint(cols + 1) + row_idx1 = random_state.randint(rows + 1) + row_idx2 = random_state.randint(rows + 1) + col = get_random_color(background_color) + cv.line(img, (warped_points[row_idx1 * (cols + 1) + col_idx, 0], + warped_points[row_idx1 * (cols + 1) + col_idx, 1]), + (warped_points[row_idx2 * (cols + 1) + col_idx, 0], + warped_points[row_idx2 * (cols + 1) + col_idx, 1]), + col, thickness) + + # Keep only the points inside the image + points = keep_points_inside(warped_points, img.shape[:2]) + return { + "points": junc_points, + "line_map": line_map + } + + +def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, + transform_params=(0.05, 0.15), seed=None): + """ Draw stripes in a distorted rectangle + and output the junctions points + line map. 
+ Parameters: + max_nb_cols: maximal number of stripes to be drawn + min_width_ratio: the minimal width of a stripe is + min_width_ratio * smallest dimension of the image + transform_params: set the range of the parameters of the transformations + """ + # Set the optional random seed (most for debugging) + if seed is None: + global random_state + else: + random_state = np.random.RandomState(seed) + + background_color = int(np.mean(img)) + # Create the grid + board_size = (int(img.shape[0] * (1 + random_state.rand())), + int(img.shape[1] * (1 + random_state.rand()))) + + # Number of cols + col = random_state.randint(5, max_nb_cols) + cols = np.concatenate([board_size[1] * random_state.rand(col - 1), + np.array([0, board_size[1] - 1])], axis=0) + cols = np.unique(cols.astype(int)) + + # Remove the indices that are too close + min_dim = min(img.shape) + + # Convert length constrain to pixel if given float number + if isinstance(min_len, float) and min_len <= 1.: + min_len = int(min_dim * min_len) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + + cols = cols[(np.concatenate([cols[1:], + np.array([board_size[1] + min_len])], + axis=0) - cols) >= min_len] + # Update the number of cols + col = cols.shape[0] - 1 + cols = np.reshape(cols, (col + 1, 1)) + cols1 = np.concatenate([cols, np.zeros((col + 1, 1), np.int32)], axis=1) + cols2 = np.concatenate( + [cols, (board_size[0] - 1) * np.ones((col + 1, 1), np.int32)], axis=1) + points = np.concatenate([cols1, cols2], axis=0) + + # Warp the grid using an affine transformation and a homography + alpha_affine = np.max(img.shape) * ( + transform_params[0] + random_state.rand() * transform_params[1]) + center_square = np.float32(img.shape) // 2 + square_size = min(img.shape) // 3 + pts1 = np.float32([center_square + square_size, + [center_square[0]+square_size, + center_square[1]-square_size], + center_square - square_size, + [center_square[0]-square_size, + 
center_square[1]+square_size]]) + pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, + size=pts1.shape).astype(np.float32) + affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3]) + pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2, + size=pts1.shape).astype(np.float32) + perspective_transform = cv.getPerspectiveTransform(pts1, pts2) + + # Apply the affine transformation + points = np.transpose(np.concatenate((points, + np.ones((2 * (col + 1), 1))), + axis=1)) + warped_points = np.transpose(np.dot(affine_transform, points)) + + # Apply the homography + warped_col0 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[0, :2]), axis=1), + perspective_transform[0, 2]) + warped_col1 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[1, :2]), axis=1), + perspective_transform[1, 2]) + warped_col2 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[2, :2]), axis=1), + perspective_transform[2, 2]) + warped_col0 = np.divide(warped_col0, warped_col2) + warped_col1 = np.divide(warped_col1, warped_col2) + warped_points = np.concatenate( + [warped_col0[:, None], warped_col1[:, None]], axis=1) + warped_points_float = warped_points.copy() + warped_points = warped_points.astype(int) + + # Fill the rectangles and get the segments + color = get_random_color(background_color) + # segments_debug = np.zeros([0, 4]) + for i in range(col): + # Fill the color + color = (color + 128 + random_state.randint(-30, 30)) % 256 + cv.fillConvexPoly(img, np.array([(warped_points[i, 0], + warped_points[i, 1]), + (warped_points[i+1, 0], + warped_points[i+1, 1]), + (warped_points[i+col+2, 0], + warped_points[i+col+2, 1]), + (warped_points[i+col+1, 0], + warped_points[i+col+1, 1])]), + color) + + segments = np.zeros([0, 4]) + row = 1 # in stripes case + # Iterate through rows + for row_idx in range(row + 1): + # Include all the combination of the junctions + # Iterate through all the combination of junction index in 
that row + multi_seg_lst = [np.array( + [warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1]])[None, ...] + for (id1, id2) in combinations(range( + row_idx * (col + 1), (row_idx + 1) * (col + 1), 1), 2)] + multi_seg = np.concatenate(multi_seg_lst, axis=0) + segments = np.concatenate((segments, multi_seg), axis=0) + + # Iterate through columns + for col_idx in range(col + 1): # for 5 columns, we will have 5 + 1 edges. + # Include all the combination of the junctions + # Iterate throuhg all the combination of junction index in that column + multi_seg_lst = [np.array( + [warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1]])[None, ...] + for (id1, id2) in combinations(range( + col_idx, col_idx + (row * col) + 2, col + 1), 2)] + multi_seg = np.concatenate(multi_seg_lst, axis=0) + segments = np.concatenate((segments, multi_seg), axis=0) + + # Select and refine the segments + segments_new = np.zeros([0, 4]) + # Define image boundary polygon (in x y manner) + image_poly = shapely.geometry.Polygon( + [[0, 0], [img.shape[1]-1, 0], [img.shape[1]-1, img.shape[0]-1], + [0, img.shape[0]-1]]) + for idx in range(segments.shape[0]): + # Get the line segment + seg_raw = segments[idx, :] + seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]]) + + # The line segment is just inside the image. + if seg.intersection(image_poly) == seg: + segments_new = np.concatenate( + (segments_new, seg_raw[None, ...]), axis=0) + + # Intersect with the image. + elif seg.intersects(image_poly): + # Check intersection + try: + p = np.array( + seg.intersection(image_poly).coords).reshape([-1, 4]) + # If intersect at exact one point, just continue. 
+ except: + continue + segment = p + segments_new = np.concatenate((segments_new, segment), axis=0) + + else: + continue + + segments = (np.round(segments_new)).astype(np.int32) + + # Only record the segments longer than min_label_len + points1 = segments[:, :2] + points2 = segments[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + label_segments = segments[seg_len >= min_label_len, :] + + # Get all junctions from label segments + junctions_all = np.concatenate( + (label_segments[:, :2], label_segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + # Get all unique junction points + else: + junc_points = np.unique(junctions_all, axis=0) + # Generate line map from points and segments + line_map = get_line_map(junc_points, label_segments) + + # Draw lines on the boundaries of the stripes at random + nb_rows = random_state.randint(2, 5) + nb_cols = random_state.randint(2, col + 2) + thickness = random_state.randint(min_dim * 0.01, min_dim * 0.011) + for _ in range(nb_rows): + row_idx = random_state.choice([0, col + 1]) + col_idx1 = random_state.randint(col + 1) + col_idx2 = random_state.randint(col + 1) + color = get_random_color(background_color) + cv.line(img, (warped_points[row_idx + col_idx1, 0], + warped_points[row_idx + col_idx1, 1]), + (warped_points[row_idx + col_idx2, 0], + warped_points[row_idx + col_idx2, 1]), + color, thickness) + + for _ in range(nb_cols): + col_idx = random_state.randint(col + 1) + color = get_random_color(background_color) + cv.line(img, (warped_points[col_idx, 0], + warped_points[col_idx, 1]), + (warped_points[col_idx + col + 1, 0], + warped_points[col_idx + col + 1, 1]), + color, thickness) + + # Keep only the points inside the image + # points = keep_points_inside(warped_points, img.shape[:2]) + return { + "points": junc_points, + "line_map": line_map + } + + +def draw_cube(img, min_size_ratio=0.2, min_label_len=64, + scale_interval=(0.4, 0.6), trans_interval=(0.5, 
0.2)): + """ Draw a 2D projection of a cube and output the visible juntions. + Parameters: + min_size_ratio: min(img.shape) * min_size_ratio is the smallest + achievable cube side size + scale_interval: the scale is between scale_interval[0] and + scale_interval[0]+scale_interval[1] + trans_interval: the translation is between img.shape*trans_interval[0] + and img.shape*(trans_interval[0] + trans_interval[1]) + """ + # Generate a cube and apply to it an affine transformation + # The order matters! + # The indices of two adjacent vertices differ only of one bit (Gray code) + background_color = int(np.mean(img)) + min_dim = min(img.shape[:2]) + min_side = min_dim * min_size_ratio + lx = min_side + random_state.rand() * 2 * min_dim / 3 # dims of the cube + ly = min_side + random_state.rand() * 2 * min_dim / 3 + lz = min_side + random_state.rand() * 2 * min_dim / 3 + cube = np.array([[0, 0, 0], + [lx, 0, 0], + [0, ly, 0], + [lx, ly, 0], + [0, 0, lz], + [lx, 0, lz], + [0, ly, lz], + [lx, ly, lz]]) + rot_angles = random_state.rand(3) * 3 * math.pi / 10. + math.pi / 10. 
+ rotation_1 = np.array([[math.cos(rot_angles[0]), + -math.sin(rot_angles[0]), 0], + [math.sin(rot_angles[0]), + math.cos(rot_angles[0]), 0], + [0, 0, 1]]) + rotation_2 = np.array([[1, 0, 0], + [0, math.cos(rot_angles[1]), + -math.sin(rot_angles[1])], + [0, math.sin(rot_angles[1]), + math.cos(rot_angles[1])]]) + rotation_3 = np.array([[math.cos(rot_angles[2]), 0, + -math.sin(rot_angles[2])], + [0, 1, 0], + [math.sin(rot_angles[2]), 0, + math.cos(rot_angles[2])]]) + scaling = np.array([[scale_interval[0] + + random_state.rand() * scale_interval[1], 0, 0], + [0, scale_interval[0] + + random_state.rand() * scale_interval[1], 0], + [0, 0, scale_interval[0] + + random_state.rand() * scale_interval[1]]]) + trans = np.array([img.shape[1] * trans_interval[0] + + random_state.randint(-img.shape[1] * trans_interval[1], + img.shape[1] * trans_interval[1]), + img.shape[0] * trans_interval[0] + + random_state.randint(-img.shape[0] * trans_interval[1], + img.shape[0] * trans_interval[1]), + 0]) + cube = trans + np.transpose( + np.dot(scaling, np.dot(rotation_1, + np.dot(rotation_2, np.dot(rotation_3, np.transpose(cube)))))) + + # The hidden corner is 0 by construction + # The front one is 7 + cube = cube[:, :2] # project on the plane z=0 + cube = cube.astype(int) + points = cube[1:, :] # get rid of the hidden corner + + # Get the three visible faces + faces = np.array([[7, 3, 1, 5], [7, 5, 4, 6], [7, 6, 2, 3]]) + + # Get all visible line segments + segments = np.zeros([0, 4]) + # Iterate through all the faces + for face_idx in range(faces.shape[0]): + face = faces[face_idx, :] + # Brute-forcely expand all the segments + segment = np.array( + [np.concatenate((cube[face[0]], cube[face[1]]), axis=0), + np.concatenate((cube[face[1]], cube[face[2]]), axis=0), + np.concatenate((cube[face[2]], cube[face[3]]), axis=0), + np.concatenate((cube[face[3]], cube[face[0]]), axis=0)]) + segments = np.concatenate((segments, segment), axis=0) + + # Select and refine the segments + segments_new = 
np.zeros([0, 4]) + # Define image boundary polygon (in x y manner) + image_poly = shapely.geometry.Polygon( + [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1], + [0, img.shape[0] - 1]]) + for idx in range(segments.shape[0]): + # Get the line segment + seg_raw = segments[idx, :] + seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]]) + + # The line segment is just inside the image. + if seg.intersection(image_poly) == seg: + segments_new = np.concatenate( + (segments_new, seg_raw[None, ...]), axis=0) + + # Intersect with the image. + elif seg.intersects(image_poly): + try: + p = np.array( + seg.intersection(image_poly).coords).reshape([-1, 4]) + except: + continue + segment = p + segments_new = np.concatenate((segments_new, segment), axis=0) + + else: + continue + + segments = (np.round(segments_new)).astype(np.int32) + + # Only record the segments longer than min_label_len + points1 = segments[:, :2] + points2 = segments[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + label_segments = segments[seg_len >= min_label_len, :] + + # Get all junctions from label segments + junctions_all = np.concatenate( + (label_segments[:, :2], label_segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + # Get all unique junction points + else: + junc_points = np.unique(junctions_all, axis=0) + # Generate line map from points and segments + line_map = get_line_map(junc_points, label_segments) + + # Fill the faces and draw the contours + col_face = get_random_color(background_color) + for i in [0, 1, 2]: + cv.fillPoly(img, [cube[faces[i]].reshape((-1, 1, 2))], + col_face) + thickness = random_state.randint(min_dim * 0.003, min_dim * 0.015) + for i in [0, 1, 2]: + for j in [0, 1, 2, 3]: + col_edge = (col_face + 128 + + random_state.randint(-64, 64))\ + % 256 # color that constrats with the face color + cv.line(img, (cube[faces[i][j], 0], cube[faces[i][j], 1]), + (cube[faces[i][(j + 1) % 4], 
def sample_homography(
        shape, perspective=True, scaling=True, rotation=True,
        translation=True, n_scales=5, n_angles=25, scaling_amplitude=0.1,
        perspective_amplitude_x=0.1, perspective_amplitude_y=0.1,
        patch_ratio=0.5, max_angle=pi/2, allow_artifacts=False,
        translation_overflow=0.):
    """
    Sample a random homography between a patch of the original image and a
    warped projection with the same image size.

    The patch starts as a centered crop covering `patch_ratio` of each image
    side (in normalized [0, 1] coordinates) and is then successively
    perturbed by perspective, scaling, translation and rotation, each step
    being optional. The homography mapping the image corners onto the final
    patch corners is obtained by solving the usual 8x8 linear system.

    Arguments:
        shape: Sequence or array `[height, width]` of the original image.
        perspective: A boolean that enables the perspective and affine transformations.
        scaling: A boolean that enables the random scaling of the patch.
        rotation: A boolean that enables the random rotation of the patch.
        translation: A boolean that enables the random translation of the patch.
        n_scales: The number of tentative scales that are sampled when scaling.
        n_angles: The number of tentative angles that are sampled when rotating.
        scaling_amplitude: Controls the amount of scale. The special value -1
            keeps the patch unscaled (selected scale fixed to 1.0).
        perspective_amplitude_x: Controls the perspective effect in x direction.
        perspective_amplitude_y: Controls the perspective effect in y direction.
        patch_ratio: Controls the size of the patches used to create the homography.
        max_angle: Maximum angle used in rotations.
        allow_artifacts: A boolean that enables artifacts when applying the homography.
        translation_overflow: Amount of border artifacts caused by translation.

    Returns:
        homo_mat: A numpy array of shape `[3, 3]` corresponding to the
                  homography transform.
        selected_scale: The selected scaling factor (1.0 when scaling is
                        disabled).
    """
    # Convert shape to ndarray
    if not isinstance(shape, np.ndarray):
        shape = np.array(shape)

    # Scale reported when no random scaling is applied. Without this default
    # the final `return` fails with UnboundLocalError for scaling=False.
    selected_scale = 1.0

    # Corners of the output image
    pts1 = np.array([[0., 0.], [0., 1.], [1., 1.], [1., 0.]])
    # Corners of the input patch
    margin = (1 - patch_ratio) / 2
    pts2 = margin + np.array([[0, 0], [0, patch_ratio],
                              [patch_ratio, patch_ratio], [patch_ratio, 0]])

    # Random perspective and affine perturbations
    if perspective:
        if not allow_artifacts:
            # Keep the displacement amplitude within the free margin
            perspective_amplitude_x = min(perspective_amplitude_x, margin)
            perspective_amplitude_y = min(perspective_amplitude_y, margin)

        # normal distribution with mean=0, std=perspective_amplitude_y/2
        perspective_displacement = np.random.normal(
            0., perspective_amplitude_y/2, [1])
        h_displacement_left = np.random.normal(
            0., perspective_amplitude_x/2, [1])
        h_displacement_right = np.random.normal(
            0., perspective_amplitude_x/2, [1])
        pts2 += np.stack([np.concatenate([h_displacement_left,
                                          perspective_displacement], 0),
                          np.concatenate([h_displacement_left,
                                          -perspective_displacement], 0),
                          np.concatenate([h_displacement_right,
                                          perspective_displacement], 0),
                          np.concatenate([h_displacement_right,
                                          -perspective_displacement], 0)])

    # Random scaling: sample several scales, check collision with borders,
    # randomly pick a valid one
    if scaling:
        if scaling_amplitude == -1:
            # Sentinel value: the patch is left unscaled
            selected_scale = 1.0
            center = np.mean(pts2, axis=0, keepdims=True)
            pts2 = (pts2 - center) * selected_scale + center
        else:
            # Index 0 is the identity scale, indices 1..n_scales are random
            scales = np.concatenate(
                [[1.], np.random.normal(1, scaling_amplitude/2, [n_scales])], 0)
            center = np.mean(pts2, axis=0, keepdims=True)
            scaled = (pts2 - center)[None, ...] * scales[..., None, None] + center
            if allow_artifacts:
                # All scales are valid except scale=1 (index 0)
                valid = np.arange(1, n_scales + 1)
            else:
                # Keep only the scales whose patch stays inside the image
                valid = np.where(np.all((scaled >= 0.)
                                        & (scaled < 1.), (1, 2)))[0]
            # No valid scale found => recursively call
            if valid.shape[0] == 0:
                return sample_homography(
                    shape, perspective, scaling, rotation, translation,
                    n_scales, n_angles, scaling_amplitude,
                    perspective_amplitude_x, perspective_amplitude_y,
                    patch_ratio, max_angle, allow_artifacts, translation_overflow)

            idx = valid[np.random.uniform(0., valid.shape[0], ()).astype(np.int32)]
            pts2 = scaled[idx]

            # Additionally save and return the selected scale.
            selected_scale = scales[idx]

    # Random translation
    if translation:
        # Largest shifts (per axis) that keep the patch inside [0, 1]
        t_min, t_max = np.min(pts2, axis=0), np.min(1 - pts2, axis=0)
        if allow_artifacts:
            t_min += translation_overflow
            t_max += translation_overflow
        pts2 += (np.stack([np.random.uniform(-t_min[0], t_max[0], ()),
                           np.random.uniform(-t_min[1],
                                             t_max[1], ())]))[None, ...]

    # Random rotation: sample several rotations, check collision with borders,
    # randomly pick a valid one
    if rotation:
        angles = np.linspace(-max_angle, max_angle, n_angles)
        # in case no rotation is valid (index 0 is the identity rotation)
        angles = np.concatenate([[0.], angles], axis=0)
        center = np.mean(pts2, axis=0, keepdims=True)
        rot_mat = np.reshape(np.stack(
            [np.cos(angles), -np.sin(angles),
             np.sin(angles), np.cos(angles)], axis=1), [-1, 2, 2])
        rotated = np.matmul(
            np.tile((pts2 - center)[None, ...], [n_angles+1, 1, 1]),
            rot_mat) + center
        if allow_artifacts:
            # All angles are valid, except angle=0 (index 0)
            valid = np.arange(1, n_angles + 1)
        else:
            valid = np.where(np.all((rotated >= 0.)
                                    & (rotated < 1.), axis=(1, 2)))[0]

        # No valid rotation found => recursively call
        if valid.shape[0] == 0:
            return sample_homography(
                shape, perspective, scaling, rotation, translation,
                n_scales, n_angles, scaling_amplitude,
                perspective_amplitude_x, perspective_amplitude_y,
                patch_ratio, max_angle, allow_artifacts, translation_overflow)

        idx = valid[np.random.uniform(0., valid.shape[0],
                                      ()).astype(np.int32)]
        pts2 = rotated[idx]

    # Rescale to actual size
    shape = shape[::-1].astype(np.float32)  # different convention [y, x]
    pts1 *= shape[None, ...]
    pts2 *= shape[None, ...]

    # Rows of the linear system A h = b that maps pts1 onto pts2
    def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]

    def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]

    a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4)
                      for f in (ax, ay)], axis=0)
    p_mat = np.transpose(np.stack([[pts2[i][j] for i in range(4)
                                    for j in range(2)]], axis=0))
    homo_vec, _, _, _ = np.linalg.lstsq(a_mat, p_mat, rcond=None)

    # Compose the homography vector back to matrix (last entry fixed to 1)
    homo_mat = np.concatenate([
        homo_vec[0:3, 0][None, ...], homo_vec[3:6, 0][None, ...],
        np.concatenate((homo_vec[6], homo_vec[7], [1]),
                       axis=0)[None, ...]], axis=0)

    return homo_mat, selected_scale
""" + # Copy the line map + line_map_tmp = copy.copy(line_map) + + line_segments = np.zeros([0, 4]) + for idx in range(junctions.shape[0]): + # If no connectivity, just skip it + if line_map_tmp[idx, :].sum() == 0: + continue + # Record the line segment + else: + for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]: + p1 = junctions[idx, :] + p2 = junctions[idx2, :] + line_segments = np.concatenate( + (line_segments, + np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), + axis=0) + # Update line_map + line_map_tmp[idx, idx2] = 0 + line_map_tmp[idx2, idx] = 0 + + return line_segments + + +def compute_valid_mask(image_size, homography, + border_margin, valid_mask=None): + # Warp the mask + if valid_mask is None: + initial_mask = np.ones(image_size) + else: + initial_mask = valid_mask + mask = cv2.warpPerspective( + initial_mask, homography, (image_size[1], image_size[0]), + flags=cv2.INTER_NEAREST) + + # Optionally perform erosion + if border_margin > 0: + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (border_margin*2, )*2) + mask = cv2.erode(mask, kernel) + + # Perform dilation if border_margin is negative + if border_margin < 0: + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (abs(int(border_margin))*2, )*2) + mask = cv2.dilate(mask, kernel) + + return mask + + +def warp_line_segment(line_segments, homography, image_size): + """ Warp the line segments using a homography. """ + # Separate the line segements into 2N points to apply matrix operation + num_segments = line_segments.shape[0] + + junctions = np.concatenate( + (line_segments[:, :2], # The first junction of each segment. + line_segments[:, 2:]), # The second junction of each segment. 
class homography_transform(object):
    """ Callable that warps an image and its line labels with a homography. """
    def __init__(self, image_size, homograpy_config,
                 border_margin=0, min_label_len=20):
        # Keyword arguments forwarded to sample_homography
        self.homo_config = homograpy_config
        self.image_size = image_size
        # cv2 expects (width, height) while image_size is indexed [0], [1]
        self.target_size = (self.image_size[1], self.image_size[0])
        self.border_margin = border_margin
        # A float below 1 is rejected: the length must be given in pixels
        if (min_label_len < 1) and isinstance(min_label_len, float):
            raise ValueError("[Error] min_label_len should be in pixels.")
        self.min_label_len = min_label_len

    def __call__(self, input_image, junctions, line_map,
                 valid_mask=None, homo=None, scale=None):
        """ Warp `input_image`, `junctions` and `line_map`.

        A fresh homography is sampled unless both `homo` and `scale` are
        given. NOTE: a caller-provided `homo` is discarded when `scale` is
        None.

        Returns a dict with keys "junctions", "warped_image", "valid_mask",
        "line_map", "homo" and "scale".
        """
        # Sample one random homography or use the given one
        must_sample = (homo is None) or (scale is None)
        if must_sample:
            homo, scale = sample_homography(self.image_size,
                                            **self.homo_config)

        # Warp the image and the validity mask
        warped_image = cv2.warpPerspective(
            input_image, homo, self.target_size, flags=cv2.INTER_LINEAR)
        valid_mask = compute_valid_mask(self.image_size, homo,
                                        self.border_margin, valid_mask)

        # Labels -> explicit segments -> warped segments
        line_segments = convert_to_line_segments(junctions, line_map)
        warped_segments = warp_line_segment(line_segments, homo,
                                            self.image_size)

        # Stack both endpoints of every warped segment
        endpoints = np.concatenate((warped_segments[:, :2],
                                    warped_segments[:, 2:]), axis=0)
        if endpoints.shape[0] == 0:
            # Nothing survived the warp: return empty labels
            junctions_new = np.zeros([0, 2])
            line_map = np.zeros([0, 0])
        else:
            junctions_new = np.unique(endpoints, axis=0)
            # Rebuild the line map from the unique points and the segments
            line_map = get_line_map(junctions_new,
                                    warped_segments).astype(np.int32)

        outputs = {
            "junctions": junctions_new,
            "warped_image": warped_image,
            "valid_mask": valid_mask,
            "line_map": line_map,
            "homo": homo,
            "scale": scale
        }
        return outputs
""" + def __init__(self, stddev_range=None): + # If std is not given, use the default setting + if stddev_range is None: + self.stddev_range = [5, 95] + else: + self.stddev_range = stddev_range + + def __call__(self, input_image): + # Get the noise stddev + stddev = np.random.uniform(self.stddev_range[0], self.stddev_range[1]) + noise = np.random.normal(0., stddev, size=input_image.shape) + noisy_image = (input_image + noise).clip(0., 255.) + + return noisy_image + + +class additive_speckle_noise(object): + """ Additive speckle noise. """ + def __init__(self, prob_range=None): + # If prob range is not given, use the default setting + if prob_range is None: + self.prob_range = [0.0, 0.005] + else: + self.prob_range = prob_range + + def __call__(self, input_image): + # Sample + prob = np.random.uniform(self.prob_range[0], self.prob_range[1]) + sample = np.random.uniform(0., 1., size=input_image.shape) + + # Get the mask + mask0 = sample <= prob + mask1 = sample >= (1 - prob) + + # Mask the image (here we assume the image ranges from 0~255 + noisy = input_image.copy() + noisy[mask0] = 0. + noisy[mask1] = 255. + + return noisy + + +class random_brightness(object): + """ Brightness change. """ + def __init__(self, brightness=None): + # If the brightness is not given, use the default setting + if brightness is None: + self.brightness = 0.5 + else: + self.brightness = brightness + + # Initialize the transformer + self.transform = transforms.ColorJitter(brightness=self.brightness) + + def __call__(self, input_image): + # Convert to PIL image + if isinstance(input_image, np.ndarray): + input_image = Image.fromarray(input_image.astype(np.uint8)) + + return np.array(self.transform(input_image)) + + +class random_contrast(object): + """ Additive contrast. 
""" + def __init__(self, contrast=None): + # If the brightness is not given, use the default setting + if contrast is None: + self.contrast = 0.5 + else: + self.contrast = contrast + + # Initialize the transformer + self.transform = transforms.ColorJitter(contrast=self.contrast) + + def __call__(self, input_image): + # Convert to PIL image + if isinstance(input_image, np.ndarray): + input_image = Image.fromarray(input_image.astype(np.uint8)) + + return np.array(self.transform(input_image)) + + +class additive_shade(object): + """ Additive shade. """ + def __init__(self, nb_ellipses=20, transparency_range=None, + kernel_size_range=None): + self.nb_ellipses = nb_ellipses + if transparency_range is None: + self.transparency_range = [-0.5, 0.8] + else: + self.transparency_range = transparency_range + + if kernel_size_range is None: + self.kernel_size_range = [250, 350] + else: + self.kernel_size_range = kernel_size_range + + def __call__(self, input_image): + # ToDo: if we should convert to numpy array first. 
+ min_dim = min(input_image.shape[:2]) / 4 + mask = np.zeros(input_image.shape[:2], np.uint8) + for i in range(self.nb_ellipses): + ax = int(max(np.random.rand() * min_dim, min_dim / 5)) + ay = int(max(np.random.rand() * min_dim, min_dim / 5)) + max_rad = max(ax, ay) + x = np.random.randint(max_rad, input_image.shape[1] - max_rad) + y = np.random.randint(max_rad, input_image.shape[0] - max_rad) + angle = np.random.rand() * 90 + cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1) + + transparency = np.random.uniform(*self.transparency_range) + kernel_size = np.random.randint(*self.kernel_size_range) + + # kernel_size has to be odd + if (kernel_size % 2) == 0: + kernel_size += 1 + mask = cv2.GaussianBlur(mask.astype(np.float32), + (kernel_size, kernel_size), 0) + shaded = (input_image[..., None] + * (1 - transparency * mask[..., np.newaxis]/255.)) + shaded = np.clip(shaded, 0, 255) + + return np.reshape(shaded, input_image.shape) + + +class motion_blur(object): + """ Motion blur. """ + def __init__(self, max_kernel_size=10): + self.max_kernel_size = max_kernel_size + + def __call__(self, input_image): + # Either vertical, horizontal or diagonal blur + mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up']) + ksize = np.random.randint( + 0, int(round((self.max_kernel_size + 1) / 2))) * 2 + 1 + center = int((ksize - 1) / 2) + kernel = np.zeros((ksize, ksize)) + if mode == 'h': + kernel[center, :] = 1. + elif mode == 'v': + kernel[:, center] = 1. + elif mode == 'diag_down': + kernel = np.eye(ksize) + elif mode == 'diag_up': + kernel = np.flip(np.eye(ksize), 0) + var = ksize * ksize / 16. + grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1) + gaussian = np.exp(-(np.square(grid - center) + + np.square(grid.T - center)) / (2. 
class normalize_image(object):
    """ Divide the image by 255 and cast the result to float32. """
    def __init__(self):
        # Divisor applied to every pixel value
        self.normalize_value = 255

    def __call__(self, input_image):
        rescaled = input_image / self.normalize_value
        return rescaled.astype(np.float32)
def process_junctions_and_line_map(h_start, w_start, H, W, H_scale, W_scale,
                                   junctions, line_map, mode="zoom-in"):
    """ Update junctions and line map after a rescaling of the image.

    Junctions are indexed as (row, column): column 0 is rescaled by
    H_scale / H and column 1 by W_scale / W.

    Arguments:
        h_start, w_start: Top-left corner of the crop window ("zoom-in") or
            of the pasting location ("zoom-out").
        H, W: Size of the original image.
        H_scale, W_scale: Size of the rescaled image.
        junctions: Array of junction coordinates, one (row, column) per row.
            It is copied, never modified in place.
        line_map: Junction connectivity passed to `convert_to_line_segments`
            ("zoom-in") or returned untouched ("zoom-out").
        mode: Either "zoom-in" (rescale then crop) or "zoom-out" (rescale
            then shift).

    Returns:
        junctions, line_map: The updated junctions and line map.

    Raises:
        ValueError: If `mode` is neither "zoom-in" nor "zoom-out".
    """
    if mode not in ("zoom-in", "zoom-out"):
        raise ValueError("[Error] unknown mode...")

    # Work on a copy (same dtype) so the caller's array is left untouched
    junctions = np.array(junctions, copy=True)

    if mode == "zoom-out":
        # Rescale and shift to the pasting location
        junctions[:, 0] = (junctions[:, 0] * H_scale / H) + h_start
        junctions[:, 1] = (junctions[:, 1] * W_scale / W) + w_start
        return junctions, line_map

    # mode == "zoom-in"
    junctions[:, 0] = junctions[:, 0] * H_scale / H
    junctions[:, 1] = junctions[:, 1] * W_scale / W
    line_segments = homoaug.convert_to_line_segments(junctions, line_map)

    # Crop segments to the new boundaries (polygon given in xy order)
    line_segments_new = np.zeros([0, 4])
    image_poly = sg.Polygon(
        [[w_start, h_start],
         [w_start+W, h_start],
         [w_start+W, h_start+H],
         [w_start, h_start+H]
         ])
    for idx in range(line_segments.shape[0]):
        # Get the line segment
        seg_raw = line_segments[idx, :]  # in HW format.
        # Convert to shapely line (flip to xy format)
        seg = sg.LineString([np.flip(seg_raw[:2]),
                             np.flip(seg_raw[2:])])
        # The line segment is just inside the image.
        if seg.intersection(image_poly) == seg:
            line_segments_new = np.concatenate(
                (line_segments_new, seg_raw[None, ...]), axis=0)
        # Intersect with the image.
        elif seg.intersects(image_poly):
            # Check intersection
            try:
                p = np.array(
                    seg.intersection(image_poly).coords).reshape([-1, 4])
            # If intersect at exact one point, just continue.
            # (Exception instead of a bare except, so that
            # KeyboardInterrupt / SystemExit are no longer swallowed.)
            except Exception:
                continue
            # Flip the clipped endpoints back to HW format
            segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:],
                                     axis=0)])[None, ...]
            line_segments_new = np.concatenate(
                (line_segments_new, segment), axis=0)
        else:
            continue
    line_segments_new = (np.round(line_segments_new)).astype(np.int32)

    # Filter segments with 0 length
    segment_lens = np.linalg.norm(
        line_segments_new[:, :2] - line_segments_new[:, 2:], axis=-1)
    seg_mask = segment_lens != 0
    line_segments_new = line_segments_new[seg_mask, :]

    # Convert back to junctions and line_map
    junctions_new = np.concatenate(
        (line_segments_new[:, :2], line_segments_new[:, 2:]), axis=0)
    if junctions_new.shape[0] == 0:
        junctions_new = np.zeros([0, 2])
        line_map = np.zeros([0, 0])
    else:
        junctions_new = np.unique(junctions_new, axis=0)
        # Generate line map from points and segments
        line_map = get_line_map(junctions_new,
                                line_segments_new).astype(np.int32)
        # Express the junctions relative to the crop origin
        junctions_new[:, 0] -= h_start
        junctions_new[:, 1] -= w_start

    return junctions_new, line_map
def wireframe_collate_fn(batch):
    """ Collate wireframe samples: stack some fields, keep others as lists.

    A field is selected by substring match of its key against two name
    lists. Fields matching neither list are dropped; a field matching both
    raises a ValueError.
    """
    # Fields merged with the default pytorch collate
    stacked_tags = ("image", "junction_map", "valid_mask", "heatmap",
                    "heatmap_pos", "heatmap_neg", "homography",
                    "line_points", "line_indices")
    # Fields kept as one python list entry per sample
    listed_tags = ("junctions", "line_map", "line_map_pos",
                   "line_map_neg", "file_key")

    collated = {}
    for name in batch[0].keys():
        is_stacked = any(tag in name for tag in stacked_tags)
        is_listed = any(tag in name for tag in listed_tags)

        if is_stacked and is_listed:
            raise ValueError(
        "[Error] A key matches batch keys and list keys simultaneously.")
        if not (is_stacked or is_listed):
            # Unknown field: silently dropped
            continue

        column = [sample[name] for sample in batch]
        if is_stacked:
            collated[name] = torch_loader.default_collate(column)
        else:
            collated[name] = column

    return collated
Only 'train' and 'test'.") + self.mode = mode + + if config is None: + self.config = self.get_default_config() + else: + self.config = config + # Also get the default config + self.default_config = self.get_default_config() + + # Get cache setting + self.dataset_name = self.get_dataset_name() + self.cache_name = self.get_cache_name() + self.cache_path = cfg.wireframe_cache_path + + # Get the ground truth source + self.gt_source = self.config.get("gt_source_%s"%(self.mode), + "official") + + if not self.gt_source == "official": + # Convert gt_source to full path + self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source) + # Check the full path exists + if not os.path.exists(self.gt_source): + raise ValueError( + "[Error] The specified ground truth source does not exist.") + + + # Get the filename dataset + print("[Info] Initializing wireframe dataset...") + self.filename_dataset, self.datapoints = self.construct_dataset() + + # Get dataset length + self.dataset_length = len(self.datapoints) + # self.dataset_length = len(self.datapoints)//4 + + # Get repeatability evaluation set + if self.mode == "test" and self.config.get("evaluation", None) is not None: + # Get the cache name + tmp = self.cache_name.split(self.mode) + self.rep_i_cache_name = tmp[0] + self.mode + "_rep_i" + tmp[1] + self.rep_v_cache_name = tmp[0] + self.mode + "_rep_v" + tmp[1] + + # Get the repeatability config + self.rep_config = self.config["evaluation"]["repeatability"] + self.rep_eval_dataset = self.construct_rep_eval_dataset() + self.rep_eval_datapoints = self.get_rep_eval_datapoints() + # Print some info + print("[Info] Successfully initialized dataset") + print("\t Name: wireframe") + print("\t Mode: %s" %(self.mode)) + print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode), + "official"))) + print("\t Counts: %d" %(self.dataset_length)) + print("----------------------------------------") + + ####################################### + ## Dataset construction related APIs ## 
def create_filename_dataset_cache(self, filename_dataset, datapoints):
    """ Create filename dataset cache for faster initialization.

    Pickles `filename_dataset` and `datapoints` into the file
    `self.cache_name` inside the directory `self.cache_path`, creating the
    directory when needed.
    """
    # exist_ok avoids the check-then-create race when several processes
    # build the cache at the same time.
    os.makedirs(self.cache_path, exist_ok=True)

    cache_file_path = os.path.join(self.cache_path, self.cache_name)
    data = {
        "filename_dataset": filename_dataset,
        "datapoints": datapoints
    }
    with open(cache_file_path, "wb") as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
""" + # Load from pkl cache + cache_file_path = os.path.join(self.cache_path, self.cache_name) + with open(cache_file_path, "rb") as f: + data = pickle.load(f) + + return data["filename_dataset"], data["datapoints"] + + def get_filename_dataset(self): + # Get the path to the dataset + if self.mode == "train": + dataset_path = os.path.join(cfg.wireframe_dataroot, "train") + elif self.mode == "test": + dataset_path = os.path.join(cfg.wireframe_dataroot, "valid") + + # Get paths to all image files + image_paths = sorted([os.path.join(dataset_path, _) + for _ in os.listdir(dataset_path)\ + if os.path.splitext(_)[-1] == ".png"]) + # Get the shared prefix + prefix_paths = [_.split(".png")[0] for _ in image_paths] + + # Get the label paths (different procedure for different split) + if self.mode == "train": + label_paths = [_ + "_label.npz" for _ in prefix_paths] + else: + label_paths = [_ + "_label.npz" for _ in prefix_paths] + mat_paths = [p[:-2] + "_line.mat" for p in prefix_paths] + + # Verify all the images and labels exist + for idx in range(len(image_paths)): + image_path = image_paths[idx] + label_path = label_paths[idx] + if (not (os.path.exists(image_path) + and os.path.exists(label_path))): + raise ValueError( + "[Error] The image and label do not exist. %s"%(image_path)) + # Further verify mat paths for test split + if self.mode == "test": + mat_path = mat_paths[idx] + if not os.path.exists(mat_path): + raise ValueError( + "[Error] The mat file does not exist. 
%s"%(mat_path)) + + # Construct the filename dataset + num_pad = int(math.ceil(math.log10(len(image_paths))) + 1) + filename_dataset = {} + for idx in range(len(image_paths)): + # Get the file key + key = self.get_padded_filename(num_pad, idx) + + filename_dataset[key] = { + "image": image_paths[idx], + "label": label_paths[idx] + } + + # Get the datapoints + datapoints = list(sorted(filename_dataset.keys())) + + return filename_dataset, datapoints + + def get_dataset_name(self): + """ Get dataset name from dataset config / default config. """ + if self.config["dataset_name"] is None: + dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode + else: + dataset_name = self.config["dataset_name"] + "_%s" % self.mode + + return dataset_name + + def get_cache_name(self): + """ Get cache name from dataset config / default config. """ + if self.config["dataset_name"] is None: + dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode + else: + dataset_name = self.config["dataset_name"] + "_%s" % self.mode + # Compose cache name + cache_name = dataset_name + "_cache.pkl" + + return cache_name + + @staticmethod + def get_padded_filename(num_pad, idx): + """ Get the padded filename using adaptive padding. """ + file_len = len("%d" % (idx)) + filename = "0" * (num_pad - file_len) + "%d" % (idx) + + return filename + + def get_default_config(self): + """ Get the default configuration. 
""" + return { + "dataset_name": "wireframe", + "add_augmentation_to_all_splits": False, + "preprocessing": { + "resize": [240, 320], + "blur_size": 11 + }, + "augmentation":{ + "photometric":{ + "enable": False + }, + "homographic":{ + "enable": False + }, + }, + } + + ########################################### + ## Repeatability evaluation related APIs ## + ########################################### + # Construct repeatability evaluation dataset (from scratch or from cache) + def construct_rep_eval_dataset(self): + rep_eval_dataset = {} + # Check if viewpoint and illumination cache exists + if self.rep_config["photometric"]["enable"]: + if self._check_rep_eval_dataset_cache(split="i"): + print("\t Found repeatability illumination cache %s at %s"%(self.rep_i_cache_name, self.cache_path)) + print("\t Load repeatability illumination cache...") + rep_i_keymap, rep_i_dataset_name = self.get_rep_eval_dataset_from_cache(split="i") + else: + print("\t Can't find repeatability illumination cache ...") + print("\t Create repeatability illumination dataset from scratch...") + rep_i_keymap, rep_i_dataset_name = self.get_rep_eval_dataset(split="i") + print("\t Create filename dataset cache...") + self.create_rep_eval_dataset_cache("i", rep_i_keymap, rep_i_dataset_name) + else: + rep_i_keymap = None + rep_i_dataset_name = None + + rep_eval_dataset["illumination"] = { + "keymap": rep_i_keymap, + "dataset_name": rep_i_dataset_name + } + + if self.rep_config["homographic"]["enable"]: + if self._check_rep_eval_dataset_cache(split="v"): + print("\t Found repeatability viewpoint cache %s at %s"%(self.rep_v_cache_name, self.cache_path)) + print("\t Load repeatability viewpoint cache...") + rep_v_keymap, rep_v_dataset_name = self.get_rep_eval_dataset_from_cache(split="v") + else: + print("\t Can't find repeatability viewpoint cache ...") + print("\t Create repeatability viewpoint dataset from scratch...") + rep_v_keymap, rep_v_dataset_name = self.get_rep_eval_dataset(split="v") + 
print("\t Create filename dataset cache...") + self.create_rep_eval_dataset_cache("v", rep_v_keymap, rep_v_dataset_name) + else: + rep_v_keymap = None + rep_v_dataset_name = None + + rep_eval_dataset["viewpoint"] = { + "keymap": rep_v_keymap, + "dataset_name": rep_v_dataset_name + } + + return rep_eval_dataset + + # Create filename dataset cache for faster initialization + def create_rep_eval_dataset_cache(self, split, keymap, dataset_name): + # Check cache path exists + if not os.path.exists(self.cache_path): + os.makedirs(self.cache_path) + + if split == "i": + cache_file_path = os.path.join(self.cache_path, self.rep_i_cache_name) + elif split == "v": + cache_file_path = os.path.join(self.cache_path, self.rep_v_cache_name) + else: + raise ValueError("[Error] Unknown split for repeatability evaluation.") + + data = { + "keymap": keymap, + "dataset_name": dataset_name + } + with open(cache_file_path, "wb") as f: + pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) + + # Get filename dataset from cache + def get_rep_eval_dataset_from_cache(self, split): + # Load from pkl cache + if split == "i": + cache_file_path = os.path.join(self.cache_path, self.rep_i_cache_name) + elif split == "v": + cache_file_path = os.path.join(self.cache_path, self.rep_v_cache_name) + else: + raise ValueError("[Error] Unknown split for repeatability evaluation.") + + with open(cache_file_path, "rb") as f: + data = pickle.load(f) + + return data["keymap"], data["dataset_name"] + + # Initialize the repeatability evaluation dataset from scratch + def get_rep_eval_dataset(self, split): + image_shape = self.config["preprocessing"]["resize"] + + # Initialize the illumination set + if split == "i": + # Set the random seed before continuing + seed = self.rep_config["seed"] + np.random.seed(seed) + torch.manual_seed(seed) + + raise NotImplementedError + + # Initialize the viewpoint set + elif split == "v": + # Set the random seed before continuing + seed = self.rep_config["seed"] + np.random.seed(seed) 
+ torch.manual_seed(seed) + + v_keymap = {} + # Get the name for the output h5 dataset + v_dataset_name = self.rep_v_cache_name.split(".pkl")[0] + ".h5" + v_dataset_path = os.path.join(self.cache_path, v_dataset_name) + with h5py.File(v_dataset_path, "w") as f: + # Iterate through all the file_key in test set + for idx, key in enumerate(tqdm(list(self.filename_dataset.keys()), ascii=True)): + # Sample N random homography + file_key_lst = [] + for i in range(self.rep_config["homographic"]["num_samples"]): + file_key = key + "_" + str(i) + + # Sample a random homography + homo_mat, _ = homoaug.sample_homography(image_shape, + **self.rep_config["homographic"]["params"]) + + file_key_lst.append(file_key) + f.create_dataset(file_key, data=homo_mat, compression="gzip") + + v_keymap[key] = file_key_lst + + return v_keymap, v_dataset_name + + else: + raise ValueError("[Error] Unknow split for repeatability evaluation.") + + # Convert ref image and warped images to list of evaluation pairs + def get_rep_eval_datapoints(self): + datapoints = { + "illumination": [], + "viewpoint": [] + } + + # Iterate through all the ref image + if self.rep_eval_dataset["illumination"]["keymap"] is not None: + for ref_key in sorted(self.rep_eval_dataset["illumination"]["keymap"].keys()): + pair_lst = [[ref_key, _] for _ in self.rep_eval_dataset["illumination"]["keymap"][ref_key]] + datapoints["illumination"] += pair_lst + + if self.rep_eval_dataset["viewpoint"]["keymap"] is not None: + for ref_key in sorted(self.rep_eval_dataset["viewpoint"]["keymap"].keys()): + pair_lst = [[ref_key, _] for _ in self.rep_eval_dataset["viewpoint"]["keymap"][ref_key]] + datapoints["viewpoint"] += pair_lst + + return datapoints + + + ########################################### + ## Repeatability evaluation related APIs ## + ########################################### + # Get the corresponding data according to the "index in rep_eval_datapoints". 
+ def get_rep_eval_data(self, split, idx): + assert split in ["viewpoint", "illumination"] + datapoint = self.rep_eval_datapoints[split][idx] + + # Get reference image + ref_key = datapoint[0] + # Get the data paths + data_path = self.filename_dataset[ref_key] + # Read in the image and npz labels (but haven't applied any transform) + image = imread(data_path["image"]) + + # Resize the image before photometric and homographical augmentations + image_size = image.shape[:2] + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize(image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) *255.).astype(np.uint8) + + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Get target image + if split == "viewpoint": + target_key = datapoint[1] + dataset_path = os.path.join(self.cache_path, self.rep_eval_dataset[split]["dataset_name"]) + + with h5py.File(dataset_path, "r") as f: + homo_mat = np.array(f[target_key]) + + # Warp the image + target_size = (image.shape[1], image.shape[0]) + target_image = cv2.warpPerspective(image, homo_mat, target_size, + flags=cv2.INTER_LINEAR) + + else: + raise NotImplementedError + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + + return { + "ref_image": to_tensor(image), + "ref_key": ref_key, + "target_image": to_tensor(target_image), + "target_key": target_key, + "homo_mat": homo_mat + } + + ############################################ + ## Pytorch and preprocessing related APIs ## + ############################################ + # Get data from the information from filename dataset + @staticmethod + def get_data_from_path(data_path): + output = {} + 
+ # Get image data + image_path = data_path["image"] + image = imread(image_path) + output["image"] = image + + # Get the npz label + """ Data entries in the npz file + jmap: [J, H, W] Junction heat map (H and W are 4x smaller) + joff: [J, 2, H, W] Junction offset within each pixel (Not sure about offsets) + lmap: [H, W] Line heat map with anti-aliasing (H and W are 4x smaller) + junc: [Na, 3] Junction coordinates (coordinates from 0~128 => 4x smaller.) + Lpos: [M, 2] Positive lines represented with junction indices + Lneg: [M, 2] Negative lines represented with junction indices + lpos: [Np, 2, 3] Positive lines represented with junction coordinates + lneg: [Nn, 2, 3] Negative lines represented with junction coordinates + """ + label_path = data_path["label"] + label = np.load(label_path) + for key in list(label.keys()): + output[key] = label[key] + + # If there's "line_mat" entry. + # TODO: How to process mat data + if data_path.get("line_mat") is not None: + raise NotImplementedError + + return output + + @staticmethod + def convert_line_map(lcnn_line_map, num_junctions): + """ Convert the line_pos or line_neg + (represented by two junction indexes) to our line map. """ + # Initialize empty line map + line_map = np.zeros([num_junctions, num_junctions]) + + # Iterate through all the lines + for idx in range(lcnn_line_map.shape[0]): + index1 = lcnn_line_map[idx, 0] + index2 = lcnn_line_map[idx, 1] + + line_map[index1, index2] = 1 + line_map[index2, index1] = 1 + + return line_map + + @staticmethod + def junc_to_junc_map(junctions, image_size): + """ Convert junction points to junction maps. 
""" + junctions = np.round(junctions).astype(np.int32) + # Clip the boundary by image size + junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1) + junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1) + + # Create junction map + junc_map = np.zeros([image_size[0], image_size[1]]) + junc_map[junctions[:, 0], junctions[:, 1]] = 1 + + return junc_map[..., None].astype(np.int32) + + def parse_transforms(self, names, all_transforms): + """ Parse the transform. """ + trans = all_transforms if (names == 'all') \ + else (names if isinstance(names, list) else [names]) + assert set(trans) <= set(all_transforms) + return trans + + def get_photo_transform(self): + """ Get list of photometric transforms (according to the config). """ + # Get the photometric transform config + photo_config = self.config["augmentation"]["photometric"] + if not photo_config["enable"]: + raise ValueError( + "[Error] Photometric augmentation is not enabled.") + + # Parse photometric transforms + trans_lst = self.parse_transforms(photo_config["primitives"], + photoaug.available_augmentations) + trans_config_lst = [photo_config["params"].get(p, {}) + for p in trans_lst] + + # List of photometric augmentation + photometric_trans_lst = [ + getattr(photoaug, trans)(**conf) \ + for (trans, conf) in zip(trans_lst, trans_config_lst) + ] + + return photometric_trans_lst + + def get_homo_transform(self): + """ Get homographic transforms (according to the config). 
""" + # Get homographic transforms for image + homo_config = self.config["augmentation"]["homographic"]["params"] + if not self.config["augmentation"]["homographic"]["enable"]: + raise ValueError( + "[Error] Homographic augmentation is not enabled.") + + # Parse the homographic transforms + image_shape = self.config["preprocessing"]["resize"] + + # Compute the min_label_len from config + try: + min_label_tmp = self.config["generation"]["min_label_len"] + except: + min_label_tmp = None + + # float label len => fraction + if isinstance(min_label_tmp, float): # Skip if not provided + min_label_len = min_label_tmp * min(image_shape) + # int label len => length in pixel + elif isinstance(min_label_tmp, int): + scale_ratio = (self.config["preprocessing"]["resize"] + / self.config["generation"]["image_size"][0]) + min_label_len = (self.config["generation"]["min_label_len"] + * scale_ratio) + # if none => no restriction + else: + min_label_len = 0 + + # Initialize the transform + homographic_trans = homoaug.homography_transform( + image_shape, homo_config, 0, min_label_len) + + return homographic_trans + + def get_line_points(self, junctions, line_map, H1=None, H2=None, + img_size=None, warp=False): + """ Sample evenly points along each line segments + and keep track of line idx. 
""" + if np.sum(line_map) == 0: + # No segment detected in the image + line_indices = np.zeros(self.config["max_pts"], dtype=int) + line_points = np.zeros((self.config["max_pts"], 2), dtype=float) + return line_points, line_indices + + # Extract all pairs of connected junctions + junc_indices = np.array( + [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i]) + line_segments = np.stack([junctions[junc_indices[:, 0]], + junctions[junc_indices[:, 1]]], axis=1) + # line_segments is (num_lines, 2, 2) + line_lengths = np.linalg.norm( + line_segments[:, 0] - line_segments[:, 1], axis=1) + + # Sample the points separated by at least min_dist_pts along each line + # The number of samples depends on the length of the line + num_samples = np.minimum(line_lengths // self.config["min_dist_pts"], + self.config["max_num_samples"]) + line_points = [] + line_indices = [] + cur_line_idx = 1 + for n in np.arange(2, self.config["max_num_samples"] + 1): + # Consider all lines where we can fit up to n points + cur_line_seg = line_segments[num_samples == n] + line_points_x = np.linspace(cur_line_seg[:, 0, 0], + cur_line_seg[:, 1, 0], + n, axis=-1).flatten() + line_points_y = np.linspace(cur_line_seg[:, 0, 1], + cur_line_seg[:, 1, 1], + n, axis=-1).flatten() + jitter = self.config.get("jittering", 0) + if jitter: + # Add a small random jittering of all points along the line + angles = np.arctan2( + cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0], + cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n) + jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter + line_points_x += jitter_hyp * np.sin(angles) + line_points_y += jitter_hyp * np.cos(angles) + line_points.append(np.stack([line_points_x, line_points_y], axis=-1)) + # Keep track of the line indices for each sampled point + num_cur_lines = len(cur_line_seg) + line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines) + line_indices.append(line_idx.repeat(n)) + cur_line_idx += num_cur_lines + line_points = 
np.concatenate(line_points, + axis=0)[:self.config["max_pts"]] + line_indices = np.concatenate(line_indices, + axis=0)[:self.config["max_pts"]] + + # Warp the points if need be, and filter unvalid ones + # If the other view is also warped + if warp and H2 is not None: + warp_points2 = warp_points(line_points, H2) + line_points = warp_points(line_points, H1) + mask = mask_points(line_points, img_size) + mask2 = mask_points(warp_points2, img_size) + mask = mask * mask2 + # If the other view is not warped + elif warp and H2 is None: + line_points = warp_points(line_points, H1) + mask = mask_points(line_points, img_size) + else: + if H1 is not None: + raise ValueError("[Error] Wrong combination of homographies.") + # Remove points that would be outside of img_size if warped by H + warped_points = warp_points(line_points, H1) + mask = mask_points(warped_points, img_size) + line_points = line_points[mask] + line_indices = line_indices[mask] + + # Pad the line points to a fixed length + # Index of 0 means padded line + line_indices = np.concatenate([line_indices, np.zeros( + self.config["max_pts"] - len(line_indices))], axis=0) + line_points = np.concatenate( + [line_points, + np.zeros((self.config["max_pts"] - len(line_points), 2), + dtype=float)], axis=0) + + return line_points, line_indices + + def train_preprocessing(self, data, numpy=False): + """ Train preprocessing for GT data. """ + # Fetch the corresponding entries + image = data["image"] + junctions = data["junc"][:, :2] + line_pos = data["Lpos"] + line_neg = data["Lneg"] + image_size = image.shape[:2] + # Convert junctions to pixel coordinates (from 128x128) + junctions[:, 0] *= image_size[0] / 128 + junctions[:, 1] *= image_size[1] / 128 + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. 
+ size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # In HW format + junctions = (junctions * np.array( + self.config['preprocessing']['resize'], np.float) + / np.array(size_old, np.float)) + + # Convert to positive line map and negative line map (our format) + num_junctions = junctions.shape[0] + line_map_pos = self.convert_line_map(line_pos, num_junctions) + line_map_neg = self.convert_line_map(line_neg, num_junctions) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + # Update image size + image_size = image.shape[:2] + heatmap_pos = get_line_heatmap(junctions_xy, line_map_pos, image_size) + heatmap_neg = get_line_heatmap(junctions_xy, line_map_neg, image_size) + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + # Check if we need to apply augmentations + # In training mode => yes. 
+ # In homography adaptation mode (export mode) => No + if self.config["augmentation"]["photometric"]["enable"]: + photo_trans_lst = self.get_photo_transform() + ### Image transform ### + np.random.shuffle(photo_trans_lst) + image_transform = transforms.Compose( + photo_trans_lst + [photoaug.normalize_image()]) + else: + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Check homographic augmentation + if self.config["augmentation"]["homographic"]["enable"]: + homo_trans = self.get_homo_transform() + # Perform homographic transform + outputs_pos = homo_trans(image, junctions, line_map_pos) + outputs_neg = homo_trans(image, junctions, line_map_neg) + + # record the warped results + junctions = outputs_pos["junctions"] # Should be HW format + image = outputs_pos["warped_image"] + line_map_pos = outputs_pos["line_map"] + line_map_neg = outputs_neg["line_map"] + heatmap_pos = outputs_pos["warped_heatmap"] + heatmap_neg = outputs_neg["warped_heatmap"] + valid_mask = outputs_pos["valid_mask"] # Same for pos and neg + + junction_map = self.junc_to_junc_map(junctions, image_size) + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + if not numpy: + return { + "image": to_tensor(image), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map_pos": to_tensor( + line_map_pos).to(torch.int32)[0, ...], + "line_map_neg": to_tensor( + line_map_neg).to(torch.int32)[0, ...], + "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32), + "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + } + else: + return { + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + "line_map_pos": line_map_pos.astype(np.int32), + "line_map_neg": line_map_neg.astype(np.int32), + "heatmap_pos": heatmap_pos.astype(np.int32), + "heatmap_neg": 
heatmap_neg.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32) + } + + def train_preprocessing_exported( + self, data, numpy=False, disable_homoaug=False, + desc_training=False, H1=None, H1_scale=None, H2=None, scale=1., + h_crop=None, w_crop=None): + """ Train preprocessing for the exported labels. """ + data = copy.deepcopy(data) + # Fetch the corresponding entries + image = data["image"] + junctions = data["junctions"] + line_map = data["line_map"] + image_size = image.shape[:2] + + # Define the random crop for scaling if necessary + if h_crop is None or w_crop is None: + h_crop, w_crop = 0, 0 + if scale > 1: + H, W = self.config["preprocessing"]["resize"] + H_scale, W_scale = round(H * scale), round(W * scale) + if H_scale > H: + h_crop = np.random.randint(H_scale - H) + if W_scale > W: + w_crop = np.random.randint(W_scale - W) + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # In HW format + junctions = (junctions * np.array( + self.config['preprocessing']['resize'], np.float) + / np.array(size_old, np.float)) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + image_size = image.shape[:2] + heatmap = get_line_heatmap(junctions_xy, line_map, image_size) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + # Check if we need to apply augmentations + # In training mode => yes. 
+ # In homography adaptation mode (export mode) => No + if self.config["augmentation"]["photometric"]["enable"]: + photo_trans_lst = self.get_photo_transform() + ### Image transform ### + np.random.shuffle(photo_trans_lst) + image_transform = transforms.Compose( + photo_trans_lst + [photoaug.normalize_image()]) + else: + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Perform the random scaling + if scale != 1.: + image, junctions, line_map, valid_mask = random_scaling( + image, junctions, line_map, scale, + h_crop=h_crop, w_crop=w_crop) + else: + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + # Initialize the empty output dict + outputs = {} + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + + # Check homographic augmentation + warp = (self.config["augmentation"]["homographic"]["enable"] + and disable_homoaug == False) + if warp: + homo_trans = self.get_homo_transform() + # Perform homographic transform + if H1 is None: + homo_outputs = homo_trans( + image, junctions, line_map, valid_mask=valid_mask) + else: + homo_outputs = homo_trans( + image, junctions, line_map, homo=H1, scale=H1_scale, + valid_mask=valid_mask) + homography_mat = homo_outputs["homo"] + + # Give the warp of the other view + if H1 is None: + H1 = homo_outputs["homo"] + + # Sample points along each line segments for the descriptor + if desc_training: + line_points, line_indices = self.get_line_points( + junctions, line_map, H1=H1, H2=H2, + img_size=image_size, warp=warp) + + # Record the warped results + if warp: + junctions = homo_outputs["junctions"] # Should be HW format + image = homo_outputs["warped_image"] + line_map = homo_outputs["line_map"] + valid_mask = homo_outputs["valid_mask"] # Same for pos and neg + # heatmap = homo_outputs["warped_heatmap"] + + # Optionally put warping information first. 
+ if not numpy: + outputs["homography_mat"] = to_tensor( + homography_mat).to(torch.float32)[0, ...] + else: + outputs["homography_mat"] = homography_mat.astype(np.float32) + + junction_map = self.junc_to_junc_map(junctions, image_size) + + if not numpy: + outputs.update({ + "image": to_tensor(image).to(torch.float32), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map": to_tensor(line_map).to(torch.int32)[0, ...], + # "heatmap": to_tensor(heatmap).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + }) + if desc_training: + outputs.update({ + "line_points": to_tensor( + line_points).to(torch.float32)[0], + "line_indices": torch.tensor(line_indices, + dtype=torch.int) + }) + else: + outputs.update({ + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + "line_map": line_map.astype(np.int32), + # "heatmap": heatmap.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32) + }) + if desc_training: + outputs.update({ + "line_points": line_points.astype(np.float32), + "line_indices": line_indices.astype(int) + }) + + return outputs + + def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.): + """ Train preprocessing for paired data for the exported labels + for descriptor training. 
""" + outputs = {} + + # Define the random crop for scaling if necessary + h_crop, w_crop = 0, 0 + if scale > 1: + H, W = self.config["preprocessing"]["resize"] + H_scale, W_scale = round(H * scale), round(W * scale) + if H_scale > H: + h_crop = np.random.randint(H_scale - H) + if W_scale > W: + w_crop = np.random.randint(W_scale - W) + + # Sample ref homography first + homo_config = self.config["augmentation"]["homographic"]["params"] + image_shape = self.config["preprocessing"]["resize"] + ref_H, ref_scale = homoaug.sample_homography(image_shape, + **homo_config) + + # Data for target view (All augmentation) + target_data = self.train_preprocessing_exported( + data, numpy=numpy, desc_training=True, H1=None, H2=ref_H, + scale=scale, h_crop=h_crop, w_crop=w_crop) + + # Data for reference view (No homographical augmentation) + ref_data = self.train_preprocessing_exported( + data, numpy=numpy, desc_training=True, H1=ref_H, + H1_scale=ref_scale, H2=target_data["homography_mat"].numpy(), + scale=scale, h_crop=h_crop, w_crop=w_crop) + + # Spread ref data + for key, val in ref_data.items(): + outputs["ref_" + key] = val + + # Spread target data + for key, val in target_data.items(): + outputs["target_" + key] = val + + return outputs + + def test_preprocessing(self, data, numpy=False): + """ Test preprocessing for GT data. """ + data = copy.deepcopy(data) + # Fetch the corresponding entries + image = data["image"] + junctions = data["junc"][:, :2] + line_pos = data["Lpos"] + line_neg = data["Lneg"] + image_size = image.shape[:2] + # Convert junctions to pixel coordinates (from 128x128) + junctions[:, 0] *= image_size[0] / 128 + junctions[:, 1] *= image_size[1] / 128 + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. 
+ size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # In HW format + junctions = (junctions * np.array( + self.config['preprocessing']['resize'], np.float) + / np.array(size_old, np.float)) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + # Still need to normalize image + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Convert to positive line map and negative line map (our format) + num_junctions = junctions.shape[0] + line_map_pos = self.convert_line_map(line_pos, num_junctions) + line_map_neg = self.convert_line_map(line_neg, num_junctions) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + # Update image size + image_size = image.shape[:2] + heatmap_pos = get_line_heatmap(junctions_xy, line_map_pos, image_size) + heatmap_neg = get_line_heatmap(junctions_xy, line_map_neg, image_size) + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + junction_map = self.junc_to_junc_map(junctions, image_size) + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + if not numpy: + return { + "image": to_tensor(image), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map_pos": to_tensor( + line_map_pos).to(torch.int32)[0, ...], + "line_map_neg": to_tensor( + line_map_neg).to(torch.int32)[0, ...], + "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32), + "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + } + else: + return { + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": 
def test_preprocessing_exported(self, data, numpy=False, scale=1.):
    """Preprocess one sample that carries exported labels (test mode).

    The sample is deep-copied, resized to ``config["preprocessing"]["resize"]``
    when its height/width differ from it, optionally converted to grayscale,
    normalized, and returned together with a line heatmap, a junction map
    and an all-ones valid mask.

    Args:
        data: mapping providing the entries "image", "junctions" and
            "line_map".
        numpy: when True the outputs are numpy arrays, otherwise they are
            converted with ``transforms.ToTensor``.
        scale: accepted by the signature but not used in this method.

    Returns:
        dict with the keys "image", "junctions", "junction_map",
        "line_map", "heatmap" and "valid_mask".
    """
    # Never modify the caller's sample.
    data = copy.deepcopy(data)
    # Fetch the corresponding entries
    image = data["image"]
    junctions = data["junctions"]
    line_map = data["line_map"]
    image_size = image.shape[:2]

    # Resize the image before photometric and homographical augmentations
    target_size = self.config["preprocessing"]["resize"]
    if not (list(image_size) == target_size):
        # Resize the image and the point location.
        size_old = list(image.shape)[:2]  # Only H and W dimensions

        # cv2.resize receives the target size reversed relative to the config.
        image = cv2.resize(
            image, tuple(target_size[::-1]),
            interpolation=cv2.INTER_LINEAR)
        image = np.array(image, dtype=np.uint8)

        # In HW format. ``np.float`` was removed in NumPy 1.24; it was an
        # alias of the builtin float, i.e. float64.
        junctions = (junctions * np.array(target_size, np.float64)
                     / np.array(size_old, np.float64))

    # Optionally convert the image to grayscale
    if self.config["gray_scale"]:
        image = (color.rgb2gray(image) * 255.).astype(np.uint8)

    # Still need to normalize image
    image_transform = photoaug.normalize_image()
    image = image_transform(image)

    # Generate the line heatmap after post-processing
    # (junction columns are flipped before being handed to the heatmap helper)
    junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
    image_size = image.shape[:2]
    heatmap = get_line_heatmap(junctions_xy, line_map, image_size)

    # Declare default valid mask (all ones)
    valid_mask = np.ones(image_size)

    junction_map = self.junc_to_junc_map(junctions, image_size)

    # Convert to tensor and return the results
    to_tensor = transforms.ToTensor()
    if not numpy:
        outputs = {
            "image": to_tensor(image),
            "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
            "junction_map": to_tensor(junction_map).to(torch.int),
            "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
            "heatmap": to_tensor(heatmap).to(torch.int32),
            "valid_mask": to_tensor(valid_mask).to(torch.int32)
        }
    else:
        outputs = {
            "image": image,
            "junctions": junctions.astype(np.float32),
            "junction_map": junction_map.astype(np.int32),
            "line_map": line_map.astype(np.int32),
            "heatmap": heatmap.astype(np.int32),
            "valid_mask": valid_mask.astype(np.int32)
        }

    return outputs
+ image: torch.float, C*H*W range 0~1, + junctions: torch.float, N*2, + junction_map: torch.int32, 1*H*W range 0 or 1, + line_map_pos: torch.int32, N*N range 0 or 1, + line_map_neg: torch.int32, N*N range 0 or 1, + heatmap_pos: torch.int32, 1*H*W range 0 or 1, + heatmap_neg: torch.int32, 1*H*W range 0 or 1, + valid_mask: torch.int32, 1*H*W range 0 or 1 + """ + # Get the corresponding datapoint and contents from filename dataset + file_key = self.datapoints[idx] + data_path = self.filename_dataset[file_key] + # Read in the image and npz labels (but haven't applied any transform) + data = self.get_data_from_path(data_path) + # Also load the exported labels if not using the official ground truth + if not self.gt_source == "official": + with h5py.File(self.gt_source, "r") as f: + exported_label = parse_h5_data(f[file_key]) + + data["junctions"] = exported_label["junctions"] + data["line_map"] = exported_label["line_map"] + data["junctions"] = data["junctions"][:,::-1] + # Perform transform and augmentation + return_type = self.config.get("return_type", "single") + if (self.mode == "train" + or self.config["add_augmentation_to_all_splits"]): + # Perform random scaling first + if self.config["augmentation"]["random_scaling"]["enable"]: + scale_range = self.config["augmentation"]["random_scaling"]["range"] + # Decide the scaling + scale = np.random.uniform(min(scale_range), max(scale_range)) + else: + scale = 1. 
def _check_dataset_cache(self):
    """Return True when the dataset cache file already exists on disk."""
    # The cache lives at <cache_path>/<cache_name>.
    return os.path.exists(os.path.join(self.cache_path, self.cache_name))
def yorkurban_collate_fn(batch):
    """Collate a list of York Urban samples into one batch dictionary.

    Keys are classified by substring match against two name lists:
    entries matching a "batch" name are merged with ``default_collate``,
    entries matching a "list" name are kept as a plain Python list (one
    element per sample), and entries matching neither are dropped.

    Raises:
        ValueError: if a key matches both name lists.
    """
    stacked_names = ["image", "junction_map", "valid_mask", "heatmap",
                     "heatmap_pos", "heatmap_neg", "homography"]
    listed_names = ["junctions", "line_map", "line_map_pos", "line_map_neg",
                    "file_key", "aux_junctions", "aux_line_map"]

    collated = {}
    # The first sample decides which keys are looked at.
    for name in batch[0].keys():
        is_stacked = any(tag in name for tag in stacked_names)
        is_listed = any(tag in name for tag in listed_names)

        if is_stacked and is_listed:
            raise ValueError("[Error] A key matches batch keys and list keys simultaneously.")
        if not is_stacked and not is_listed:
            # Unknown entries are silently left out of the batch.
            continue

        values = [sample[name] for sample in batch]
        if is_stacked:
            collated[name] = torch_loader.default_collate(values)
        else:
            collated[name] = values

    return collated
Only 'test' mode is available.") + self.mode = mode + + self.config = config + + # Get cache setting + self.dataset_name = self.get_dataset_name() + self.cache_name = self.get_cache_name() + self.cache_path = cfg.yorkurban_cache_path + + # Get the filename dataset + print("[Info] Initializing york urban dataset...") + self.filename_dataset, self.datapoints = self.get_filename_dataset() + # Get dataset length + self.dataset_length = len(self.datapoints) + + # Get repeatability evaluation set + if self.mode == "test" and self.config.get("evaluation", None) is not None: + # Get the cache name + tmp = self.cache_name.split(self.mode) + self.rep_i_cache_name = tmp[0] + self.mode + "_rep_i" + tmp[1] + self.rep_v_cache_name = tmp[0] + self.mode + "_rep_v" + tmp[1] + + # Get the repeatability config + self.rep_config = self.config["evaluation"]["repeatability"] + + self.rep_eval_dataset = self.construct_rep_eval_dataset() + self.rep_eval_datapoints = self.get_rep_eval_datapoints() + + # Print some info + print("[Info] Successfully initialized dataset") + print("\t Name: yorkurban") + print("\t Mode: %s" %(self.mode)) + print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode), "official"))) + print("\t Counts: %d" %(self.dataset_length)) + print("----------------------------------------") + + def get_filename_dataset(self): + # Get the path to the dataset + if self.mode == "train": + raise NotImplementedError + elif self.mode == "test": + dataset_path = os.path.join(cfg.yorkurban_dataroot) + # Get paths to all image files + folder_lst = sorted([os.path.join(dataset_path, _) for _ in os.listdir(dataset_path) \ + if os.path.isdir(os.path.join(dataset_path, _))]) + folder_lst = folder_lst[:-1] + #folder_lst = [f for f in folder_lst if f.startswith('P')] + image_paths = [] + for folder in folder_lst: + image_path = [os.path.join(folder, _) for _ in os.listdir(folder) \ + if os.path.splitext(_)[-1] == ".jpg" or os.path.splitext(_)[-1] == ".png"] + image_paths += image_path 
def get_cache_name(self):
    """Build the cache file name "<dataset_name>_<mode>_cache.pkl".

    The dataset name comes from ``config["dataset_name"]``; when that entry
    is None the default "yorkurban_dataset" is used.
    """
    configured_name = self.config["dataset_name"]
    base_name = "yorkurban_dataset" if configured_name is None else configured_name
    # Compose cache name
    return base_name + f"_{self.mode}" + "_cache.pkl"
self.cache_path)) + print("\t Load repeatability illumination cache...") + rep_i_keymap, rep_i_dataset_name = self.get_rep_eval_dataset_from_cache(split="i") + else: + print("\t Can't find repeatability illumination cache ...") + print("\t Create repeatability illumination dataset from scratch...") + rep_i_keymap, rep_i_dataset_name = self.get_rep_eval_dataset(split="i") + print("\t Create filename dataset cache...") + self.create_rep_eval_dataset_cache("i", rep_i_keymap, rep_i_dataset_name) + else: + rep_i_keymap = None + rep_i_dataset_name = None + + rep_eval_dataset["illumination"] = { + "keymap": rep_i_keymap, + "dataset_name": rep_i_dataset_name + } + + if self.rep_config["homographic"]["enable"]: + if self._check_rep_eval_dataset_cache(split="v"): + print("\t Found repeatability viewpoint cache %s at %s"%(self.rep_v_cache_name, self.cache_path)) + print("\t Load repeatability viewpoint cache...") + rep_v_keymap, rep_v_dataset_name = self.get_rep_eval_dataset_from_cache(split="v") + else: + print("\t Can't find repeatability viewpoint cache ...") + print("\t Create repeatability viewpoint dataset from scratch...") + rep_v_keymap, rep_v_dataset_name = self.get_rep_eval_dataset(split="v") + print("\t Create filename dataset cache...") + self.create_rep_eval_dataset_cache("v", rep_v_keymap, rep_v_dataset_name) + else: + rep_v_keymap = None + rep_v_dataset_name = None + + rep_eval_dataset["viewpoint"] = { + "keymap": rep_v_keymap, + "dataset_name": rep_v_dataset_name + } + + return rep_eval_dataset + + # Create filename dataset cache for faster initialization + def create_rep_eval_dataset_cache(self, split, keymap, dataset_name): + # Check cache path exists + if not os.path.exists(self.cache_path): + os.makedirs(self.cache_path) + + if split == "i": + cache_file_path = os.path.join(self.cache_path, self.rep_i_cache_name) + elif split == "v": + cache_file_path = os.path.join(self.cache_path, self.rep_v_cache_name) + else: + raise ValueError("[Error] Unknown split 
def get_rep_eval_dataset_from_cache(self, split):
    """Load a repeatability-evaluation cache from its pickle file.

    Args:
        split: "i" selects ``rep_i_cache_name``, "v" selects
            ``rep_v_cache_name``.

    Returns:
        Tuple ``(keymap, dataset_name)`` read from the cached dictionary.

    Raises:
        ValueError: if ``split`` is neither "i" nor "v".
    """
    if split not in ("i", "v"):
        raise ValueError("[Error] Unknown split for repeatability evaluation.")

    cache_name = self.rep_i_cache_name if split == "i" else self.rep_v_cache_name
    cache_file_path = os.path.join(self.cache_path, cache_name)

    # Load from pkl cache
    with open(cache_file_path, "rb") as cache_file:
        cached = pickle.load(cache_file)

    return cached["keymap"], cached["dataset_name"]
def get_rep_eval_datapoints(self):
    """Flatten the repeatability keymaps into lists of evaluation pairs.

    For each of the "illumination" and "viewpoint" entries of
    ``self.rep_eval_dataset``, every reference key (visited in sorted
    order) is paired with each of its mapped keys as ``[ref_key, key]``.
    An entry whose keymap is None yields an empty list.

    Returns:
        dict with the keys "illumination" and "viewpoint".
    """
    datapoints = {}
    for split_name in ("illumination", "viewpoint"):
        keymap = self.rep_eval_dataset[split_name]["keymap"]
        pairs = []
        if keymap is not None:
            # Iterate through all the ref images
            for ref_key in sorted(keymap.keys()):
                pairs.extend([ref_key, target_key] for target_key in keymap[ref_key])
        datapoints[split_name] = pairs

    return datapoints
+ def get_rep_eval_data(self, split, idx): + assert split in ["viewpoint", "illumination"] + datapoint = self.rep_eval_datapoints[split][idx] + + # Get reference image + ref_key = datapoint[0] + # Get the data paths + data_path = self.filename_dataset[ref_key] + # Read in the image and npz labels (but haven't applied any transform) + image = imread(data_path["image"]) + + # Resize the image before photometric and homographical augmentations + image_size = image.shape[:2] + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize(image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) *255.).astype(np.uint8) + + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Get target image + if split == "viewpoint": + target_key = datapoint[1] + dataset_path = os.path.join(self.cache_path, self.rep_eval_dataset[split]["dataset_name"]) + with h5py.File(dataset_path, "r") as f: + homo_mat = np.array(f[target_key]) + + # Warp the image + target_size = (image.shape[1], image.shape[0]) + target_image = cv2.warpPerspective(image, homo_mat, target_size, + flags=cv2.INTER_LINEAR) + + else: + raise NotImplementedError + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + + return { + "ref_image": to_tensor(image), + "ref_key": ref_key, + "target_image": to_tensor(target_image), + "target_key": target_key, + "homo_mat": homo_mat + } + + ############################################ + ## Pytorch and preprocessing related APIs ## + ############################################ + # Get the length of the dataset + def __len__(self): + return self.dataset_length + + # Get data from the information from 
filename dataset + @staticmethod + def get_data_from_path(data_path): + output = {} + + # Get image data + image_path = data_path["image"] + image = imread(image_path) + output["image"] = image + + return output + + # The test preprocessing + def test_preprocessing(self, data, numpy=False): + # Fetch the corresponding entries + image = data["image"] + image_size = image.shape[:2] + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize(image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) *255.).astype(np.uint8) + + # Still need to normalize image + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Update image size + image_size = image.shape[:2] + valid_mask = np.ones(image_size) + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + if not numpy: + return { + "image": to_tensor(image), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + } + else: + return { + "image": image, + "valid_mask": valid_mask.astype(np.int32) + } + + # Define the getitem method + def __getitem__(self, idx): + """Return data + file_key: str, keys used to retrieve certain data from the filename dataset. 
+ image: torch.float, C*H*W range 0~1, + valid_mask: torch.int32, 1*H*W range 0 or 1 + """ + # Get the corresponding datapoint and get contents from filename dataset + file_key = self.datapoints[idx] + data_path = self.filename_dataset[file_key] + # Read in the image and npz labels (but haven't applied any transform) + data = self.get_data_from_path(data_path) + + # Perform transform and augmentation + if self.mode == "train" or self.config["add_augmentation_to_all_splits"]: + raise NotImplementedError + else: + data = self.test_preprocessing(data) + + # Add file key to the output + data["file_key"] = file_key + + return data + +if __name__ == "__main__": + import sys + import yaml + import matplotlib + import matplotlib.pyplot as plt + plt.switch_backend("TkAgg") + from torch.utils.data import DataLoader + sys.path.append("../") + + # Load configuration file + with open("./config/yorkurban_dataset_config.yaml", "r") as f: + config = yaml.safe_load(f) + + config["add_augmentation_to_all_splits"] = False + + # Initialize the dataset + test_dataset = YorkUrbanDataset(mode="test", config=config) + import ipdb; ipdb.set_trace() \ No newline at end of file diff --git a/scalelsd/ssl/misc/__init__.py b/scalelsd/ssl/misc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scalelsd/ssl/misc/geometry_utils.py b/scalelsd/ssl/misc/geometry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d453be93a9f4e443da8f69efd3753b94354d381 --- /dev/null +++ b/scalelsd/ssl/misc/geometry_utils.py @@ -0,0 +1,282 @@ +import numpy as np +import torch + +from shapely.geometry.polygon import LinearRing + +try: + from pycolmap import image_to_world, world_to_image +except: + pass + +### Point-related utils + +# Warp a list of points using a homography +def warp_points(points, homography): + # Convert to homogeneous and in xy format + new_points = np.concatenate([points[..., [1, 0]], + 
# Convert a tensor [N, 2] or batched tensor [B, N, 2] of N keypoints into
# a sampling grid with coordinates normalized to [-1, 1]^2
def keypoints_to_grid(keypoints, img_size):
    """Normalize keypoints by the image size and reshape them into a grid.

    Each coordinate is mapped with ``2 * k / img_size - 1``, the two
    coordinates of every point are then swapped, and the result is viewed
    as a tensor of shape (-1, N, 1, 2).
    NOTE(review): the swap presumably converts (h, w) into (x, y) order,
    confirm against the callers.
    """
    num_keypoints = keypoints.size()[-2]
    extent = torch.tensor(img_size, dtype=torch.float,
                          device=keypoints.device)
    normalized = keypoints.float() * 2. / extent - 1.
    swapped = normalized[..., [1, 0]]
    return swapped.view(-1, num_keypoints, 1, 2)
# Compute the distances between two sets of lines using the sAP distance
def get_sAP_line_distance(warped_ref_line_seg, target_line_seg):
    """Return the pairwise endpoint-sum distance between two sets of segments.

    Both inputs are indexed as (num_lines, endpoint, coordinate). For every
    (reference, target) pair the Euclidean distances between endpoints are
    summed for both possible endpoint pairings, and the smaller sum is kept.
    The result has one row per reference line and one column per target line.
    """
    # Broadcast to (n_ref, n_target, ref_endpoint, target_endpoint, coord).
    diff = warped_ref_line_seg[:, None, :, None] - target_line_seg[:, None]
    endpoint_dist = ((diff ** 2).sum(-1)) ** 0.5

    # Endpoints matched in the same order vs. in flipped order.
    same_order = endpoint_dist[:, :, 0, 0] + endpoint_dist[:, :, 1, 1]
    flipped_order = endpoint_dist[:, :, 0, 1] + endpoint_dist[:, :, 1, 0]
    return np.minimum(same_order, flipped_order)
# Given a list of segments parameterized by the 1D coordinate of the endpoints
# compute the overlap with the segment [0, 1]
def get_segment_overlap(seg_coord1d):
    """Return the length of each segment's intersection with [0, 1].

    The last axis of ``seg_coord1d`` holds the two 1D endpoint coordinates
    of a segment; they are sorted first, so their order does not matter.
    Segments lying entirely outside [0, 1] get an overlap of 0.
    """
    ordered = np.sort(seg_coord1d, axis=-1)
    start = ordered[..., 0]
    end = ordered[..., 1]

    # A segment reaches into [0, 1] only if it ends after 0 and starts before 1.
    intersects = (end > 0) * (start < 1)
    clipped_length = np.minimum(end, 1) - np.maximum(start, 0)
    return intersects * clipped_length
# Convert from quaternions to rotation matrix
def qvec2rotmat(qvec):
    """Build the 3x3 rotation matrix of a quaternion.

    The first four entries of ``qvec`` are read by index; with the
    component order used here the quaternion (1, 0, 0, 0) gives the
    identity matrix.
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]

    first_row = [1 - 2 * y**2 - 2 * z**2,
                 2 * x * y - 2 * w * z,
                 2 * z * x + 2 * w * y]
    second_row = [2 * x * y + 2 * w * z,
                  1 - 2 * x**2 - 2 * z**2,
                  2 * y * z - 2 * w * x]
    third_row = [2 * z * x - 2 * w * y,
                 2 * y * z + 2 * w * x,
                 1 - 2 * x**2 - 2 * y**2]
    return np.array([first_row, second_row, third_row])
# Adapt the camera intrinsics to an image resize
def scale_intrinsics(intrinsics, scale_factor):
    """Return a copy of camera intrinsics rescaled for a resized image.

    Args:
        intrinsics: dict with the entries "model", "width", "height" and
            "params" (a numeric sequence).
        scale_factor: multiplicative resize factor applied to the image.

    Returns:
        A new dict with the same keys. Width and height are scaled and
        rounded to the nearest integer, ``params[:2]`` and ``params[2:4]``
        are scaled, and any further parameters are carried over unchanged.
        The input dict and its "params" array are left untouched.
    """
    new_intrinsics = {"model": intrinsics["model"],
                      "width": int(intrinsics["width"] * scale_factor + 0.5),
                      "height": int(intrinsics["height"] * scale_factor + 0.5)
                      }
    # Work on a copy: scaling the caller's array in place would silently
    # change the source camera and compound on repeated calls.
    params = np.array(intrinsics["params"])
    # Adapt the focal length
    # NOTE(review): assumes the first two parameters are focal lengths and
    # the next two the principal point; confirm for single-focal models.
    params[:2] *= scale_factor
    # Adapt the principal point
    params[2:4] = (params[2:4] * scale_factor + 0.5) - 0.5
    new_intrinsics["params"] = params
    return new_intrinsics
# Compute the intersections between an ellipse a and a line.
def intersect_line_ellipse(a, line):
    """Intersect a polygonal ellipse approximation with a line geometry.

    Args:
        a: sequence of 2D points describing the ellipse outline; it is
            wrapped in a ``LinearRing``.
        line: geometry object passed to ``LinearRing.intersection``.

    Returns:
        Array of shape (k, 2) holding the intersection points; k is 0 when
        the geometries do not meet.

    Raises:
        ValueError: if the intersection is neither empty, a Point nor a
            MultiPoint.
    """
    ring = LinearRing(a)
    hits = ring.intersection(line)
    if hits.is_empty:
        return np.empty((0, 2))
    elif hits.geom_type == 'Point':
        return np.array([[hits.x, hits.y]])
    elif hits.geom_type == 'MultiPoint':
        # Multi-part geometries are not iterable in Shapely 2.0; their parts
        # are reached through ``.geoms``, which older versions also provide.
        return np.array([[p.x, p.y] for p in hits.geoms])
    else:
        raise ValueError('Impossible geometry: ' + hits.geom_type)
# Plot junctions onto the image (return a separate copy)
def plot_junctions(input_image, junctions, junc_size=3, color=None):
    """Draw junctions as circles on a copy of the image.

    Args:
        input_image: uint8 image, or a float32/float64 image either in the
            0~1 range (scaled by 255) or in the 0~255 range (cast directly).
            Single-channel images are expanded to 3 channels.
        junctions: Nx2 or 2xN np array; it is rounded and clipped to the
            image bounds, with the first column bounded by the image height
            and the second by its width.
        junc_size: the size (radius) of the plotted circles.
        color: circle color; defaults to (0, 255., 0).

    Returns:
        A uint8 image with 3 channels; the input image is not modified.

    Raises:
        ValueError: if the image dtype/range is not recognized, or if
            ``junctions`` is not a 2-dim array with one dimension of size 2.
    """
    # ``np.float`` was removed in NumPy 1.24 (it aliased float64), so only
    # the concrete float dtypes are listed.
    float_dtypes = (np.float32, np.float64)

    # Create image copy
    image = copy.copy(input_image)
    # Make sure the image is converted to 255 uint8
    if image.dtype == np.uint8:
        pass
    # A float type image ranging from 0~1
    elif image.dtype in float_dtypes and image.max() <= 2.:
        image = (image * 255.).astype(np.uint8)
    # A float type image ranging from 0.~255.
    elif image.dtype in float_dtypes and image.mean() > 10.:
        image = image.astype(np.uint8)
    else:
        raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")

    # Check whether the image is single channel
    if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
        # Squeeze to H*W first
        image = image.squeeze()

        # Stack to channel 3
        image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)

    # Junction dimensions should be N*2
    if not len(junctions.shape) == 2:
        raise ValueError("[Error] junctions should be 2-dim array.")

    # Always convert to N*2
    if junctions.shape[-1] != 2:
        if junctions.shape[0] == 2:
            junctions = junctions.T
        else:
            raise ValueError("[Error] At least one of the two dims should be 2.")

    # Round and convert junctions to int (and check the boundary)
    H, W = image.shape[:2]
    junctions = (np.round(junctions)).astype(np.int32)
    junctions[junctions < 0] = 0
    junctions[junctions[:, 0] >= H, 0] = H-1  # (first dim) max bounded by H-1
    junctions[junctions[:, 1] >= W, 1] = W-1  # (second dim) max bounded by W-1

    if color is None:
        color = (0, 255., 0)

    # Iterate through all the junctions
    for junc in junctions:
        # The circle center is the junction with its two coordinates
        # flipped. Plain Python ints are passed because some OpenCV builds
        # reject numpy integer scalars in the center tuple.
        center = (int(junc[1]), int(junc[0]))
        cv2.circle(image, center, radius=junc_size,
                   color=color, thickness=3)

    return image
plot_survived_junc=True): + """ + input_image: can be 0~1 float or 0~255 uint8. + junctions: Nx2 or 2xN np array. + line_map: NxN np array + junc_size: the size of the plotted circles. + color: color of the line segments (can be string "random") + line_width: width of the drawn segments. + plot_survived_junc: whether we only plot the survived junctions. + """ + # Create image copy + image = copy.copy(input_image) + # Make sure the image is converted to 255 uint8 + if image.dtype == np.uint8: + pass + # A float type image ranging from 0~1 + elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.: + image = (image * 255.).astype(np.uint8) + # A float type image ranging from 0.~255. + elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.: + image = image.astype(np.uint8) + else: + raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.") + + # Check whether the image is single channel + if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)): + # Squeeze to H*W first + image = image.squeeze() + + # Stack to channle 3 + image = np.concatenate([image[..., None] for _ in range(3)], axis=-1) + + # Junction dimensions should be 2 + if not len(junctions.shape) == 2: + raise ValueError("[Error] junctions should be 2-dim array.") + + # Always convert to N*2 + if junctions.shape[-1] != 2: + if junctions.shape[0] == 2: + junctions = junctions.T + else: + raise ValueError("[Error] At least one of the two dims should be 2.") + + # line_map dimension should be 2 + if not len(line_map.shape) == 2: + raise ValueError("[Error] line_map should be 2-dim array.") + + # Color should be "random" or a list or tuple with length 3 + if color != "random": + if not (isinstance(color, tuple) or isinstance(color, list)): + raise ValueError("[Error] color should have type list or tuple.") + else: + if len(color) != 3: + raise ValueError("[Error] color should be a list or tuple with length 3.") + + # 
Make a copy of the line_map + line_map_tmp = copy.copy(line_map) + + # Parse line_map back to segment pairs + segments = np.zeros([0, 4]) + for idx in range(junctions.shape[0]): + # if no connectivity, just skip it + if line_map_tmp[idx, :].sum() == 0: + continue + # record the line segment + else: + for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]: + p1 = np.flip(junctions[idx, :]) # Convert to xy format + p2 = np.flip(junctions[idx2, :]) # Convert to xy format + segments = np.concatenate((segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), axis=0) + + # Update line_map + line_map_tmp[idx, idx2] = 0 + line_map_tmp[idx2, idx] = 0 + + # Draw segment pairs + for idx in range(segments.shape[0]): + seg = np.round(segments[idx, :]).astype(np.int32) + # Decide the color + if color != "random": + color = tuple(color) + else: + color = tuple(np.random.rand(3,)) + cv2.line(image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width) + + # Also draw the junctions + if not plot_survived_junc: + num_junc = junctions.shape[0] + for idx in range(num_junc): + # Fetch one junction + junc = junctions[idx, :] + cv2.circle(image, tuple(np.flip(junc)), radius=junc_size, + color=(0, 255., 0), thickness=3) + # Only plot the junctions which are part of a line segment + else: + for idx in range(segments.shape[0]): + seg = np.round(segments[idx, :]).astype(np.int32) # Already in HW format. 
+ cv2.circle(image, tuple(seg[:2]), radius=junc_size, + color=(0, 255., 0), thickness=3) + cv2.circle(image, tuple(seg[2:]), radius=junc_size, + color=(0, 255., 0), thickness=3) + + return image + + +# Plot line segments given Nx4 or Nx2x2 line segments +def plot_line_segments_from_segments(input_image, line_segments, junc_size=3, + color=(0, 255., 0), line_width=1): + # Create image copy + image = copy.copy(input_image) + # Make sure the image is converted to 255 uint8 + if image.dtype == np.uint8: + pass + # A float type image ranging from 0~1 + elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.: + image = (image * 255.).astype(np.uint8) + # A float type image ranging from 0.~255. + elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.: + image = image.astype(np.uint8) + else: + raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.") + + # Check whether the image is single channel + if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)): + # Squeeze to H*W first + image = image.squeeze() + + # Stack to channle 3 + image = np.concatenate([image[..., None] for _ in range(3)], axis=-1) + + # Check the if line_segments are in (1) Nx4, or (2) Nx2x2. 
+ H, W, _ = image.shape + # (1) Nx4 format + if len(line_segments.shape) == 2 and line_segments.shape[-1] == 4: + # Round to int32 + line_segments = line_segments.astype(np.int32) + + # Clip H dimension + line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H-1) + line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H-1) + + # Clip W dimension + line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W-1) + line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W-1) + + # Convert to Nx2x2 format + line_segments = np.concatenate( + [np.expand_dims(line_segments[:, :2], axis=1), + np.expand_dims(line_segments[:, 2:], axis=1)], + axis=1 + ) + + # (2) Nx2x2 format + elif len(line_segments.shape) == 3 and line_segments.shape[-1] == 2: + # Round to int32 + line_segments = line_segments.astype(np.int32) + + # Clip H dimension + line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H-1) + line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W-1) + + else: + raise ValueError("[Error] line_segments should be either Nx4 or Nx2x2 in HW format.") + + # Draw segment pairs (all segments should be in HW format) + image = image.copy() + for idx in range(line_segments.shape[0]): + seg = np.round(line_segments[idx, :, :]).astype(np.int32) + # Decide the color + if color != "random": + color = tuple(color) + else: + color = tuple(np.random.rand(3,)) + cv2.line(image, tuple(np.flip(seg[0, :])), + tuple(np.flip(seg[1, :])), + color=color, thickness=line_width) + + # Also draw the junctions + cv2.circle(image, tuple(np.flip(seg[0, :])), radius=junc_size, color=(0, 255., 0), thickness=3) + cv2.circle(image, tuple(np.flip(seg[1, :])), radius=junc_size, color=(0, 255., 0), thickness=3) + + return image + + +# Additional functions to visualize multiple images at the same time, +# e.g. 
for line matching +def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5): + """Plot a set of images horizontally. + Args: + imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. + """ + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + figsize = (size*n, size*3/4) if size is not None else None + fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi) + if n == 1: + ax = [ax] + for i in range(n): + ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) + ax[i].get_yaxis().set_ticks([]) + ax[i].get_xaxis().set_ticks([]) + ax[i].set_axis_off() + for spine in ax[i].spines.values(): # remove frame + spine.set_visible(False) + if titles: + ax[i].set_title(titles[i]) + fig.tight_layout(pad=pad) + + +def plot_keypoints(kpts, colors='lime', ps=4): + """Plot keypoints for existing images. + Args: + kpts: list of ndarrays of size (N, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float. + """ + if not isinstance(colors, list): + colors = [colors] * len(kpts) + axes = plt.gcf().axes + for a, k, c in zip(axes, kpts, colors): + a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0) + + +def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.): + """Plot matches for a pair of existing images. + Args: + kpts0, kpts1: corresponding keypoints of size (N, 2). + color: color of each match, string or RGB tuple. Random if not given. + lw: width of the lines. + ps: size of the end points (no endpoint if ps=0) + indices: indices of the images to draw the matches on. + a: alpha opacity of the match lines. 
+ """ + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + ax0, ax1 = ax[indices[0]], ax[indices[1]] + fig.canvas.draw() + + assert len(kpts0) == len(kpts1) + if color is None: + color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() + elif len(color) > 0 and not isinstance(color[0], (tuple, list)): + color = [color] * len(kpts0) + + if lw > 0: + # transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) + fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) + fig.lines += [matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), + zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, + alpha=a) + for i in range(len(kpts0))] + + # freeze the axes to prevent the transform to change + ax0.autoscale(enable=False) + ax1.autoscale(enable=False) + + if ps > 0: + ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2) + ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2) + + +def plot_lines(lines, line_colors='orange', point_colors='cyan', + ps=4, lw=2, indices=(0, 1)): + """Plot lines and endpoints for existing images. + Args: + lines: list of ndarrays of size (N, 2, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float pixels. + lw: line width as float pixels. + indices: indices of the images to draw the matches on. 
+ """ + if not isinstance(line_colors, list): + line_colors = [line_colors] * len(lines) + if not isinstance(point_colors, list): + point_colors = [point_colors] * len(lines) + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines and junctions + for a, l, lc, pc in zip(axes, lines, line_colors, point_colors): + for i in range(len(l)): + line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]), + (l[i, 0, 1], l[i, 1, 1]), + zorder=1, c=lc, linewidth=lw) + a.add_line(line) + pts = l.reshape(-1, 2) + a.scatter(pts[:, 0], pts[:, 1], + c=pc, s=ps, linewidths=0, zorder=2) + + +def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.): + """Plot matches for a pair of existing images, parametrized by their middle point. + Args: + kpts0, kpts1: corresponding middle points of the lines of size (N, 2). + color: color of each match, string or RGB tuple. Random if not given. + lw: width of the lines. + indices: indices of the images to draw the matches on. + a: alpha opacity of the match lines. 
+ """ + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + ax0, ax1 = ax[indices[0]], ax[indices[1]] + fig.canvas.draw() + + assert len(kpts0) == len(kpts1) + if color is None: + color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() + elif len(color) > 0 and not isinstance(color[0], (tuple, list)): + color = [color] * len(kpts0) + + if lw > 0: + # transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) + fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) + fig.lines += [matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), + zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, + alpha=a) + for i in range(len(kpts0))] + + # freeze the axes to prevent the transform to change + ax0.autoscale(enable=False) + ax1.autoscale(enable=False) + + +def plot_color_line_matches(lines, correct_matches=None, + lw=2, indices=(0, 1)): + """Plot line matches for existing images with multiple colors. + Args: + lines: list of ndarrays of size (N, 2, 2). + correct_matches: bool array of size (N,) indicating correct matches. + lw: line width as float pixels. + indices: indices of the images to draw the matches on. 
+ """ + n_lines = len(lines[0]) + colors = sns.color_palette('husl', n_colors=n_lines) + np.random.shuffle(colors) + alphas = np.ones(n_lines) + # If correct_matches is not None, display wrong matches with a low alpha + if correct_matches is not None: + alphas[~np.array(correct_matches)] = 0.2 + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines + for a, l in zip(axes, lines): + # Transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) + fig.lines += [matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, transform=fig.transFigure, c=colors[i], + alpha=alphas[i], linewidth=lw) for i in range(n_lines)] + + +def plot_color_lines(lines, correct_matches, wrong_matches, + lw=2, indices=(0, 1)): + """Plot line matches for existing images with multiple colors: + green for correct matches, red for wrong ones, and blue for the rest. + Args: + lines: list of ndarrays of size (N, 2, 2). + correct_matches: list of bool arrays of size N with correct matches. + wrong_matches: list of bool arrays of size (N,) with correct matches. + lw: line width as float pixels. + indices: indices of the images to draw the matches on. 
+ """ + # palette = sns.color_palette() + palette = sns.color_palette("hls", 8) + blue = palette[5] # palette[0] + red = palette[0] # palette[3] + green = palette[2] # palette[2] + colors = [np.array([blue] * len(l)) for l in lines] + for i, c in enumerate(colors): + c[np.array(correct_matches[i])] = green + c[np.array(wrong_matches[i])] = red + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines + for a, l, c in zip(axes, lines, colors): + # Transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) + fig.lines += [matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, transform=fig.transFigure, c=c[i], + linewidth=lw) for i in range(len(l))] + + +def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)): + """ Plot line matches for existing images with multiple colors and + highlight the actually matched subsegments. + Args: + lines: list of ndarrays of size (N, 2, 2). + subsegments: list of ndarrays of size (N, 2, 2). + lw: line width as float pixels. + indices: indices of the images to draw the matches on. 
+ """ + n_lines = len(lines[0]) + colors = sns.cubehelix_palette(start=2, rot=-0.2, dark=0.3, light=.7, + gamma=1.3, hue=1, n_colors=n_lines) + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines + for a, l, ss in zip(axes, lines, subsegments): + # Transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + + # Draw full line + endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) + fig.lines += [matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, transform=fig.transFigure, c='red', + alpha=0.7, linewidth=lw) for i in range(n_lines)] + + # Draw matched subsegment + endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1])) + fig.lines += [matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, transform=fig.transFigure, c=colors[i], + alpha=1, linewidth=lw) for i in range(n_lines)] \ No newline at end of file diff --git a/scalelsd/ssl/models/__init__.py b/scalelsd/ssl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f33847ebe1b18e6cfa9338d0c71ee4002f65ac --- /dev/null +++ b/scalelsd/ssl/models/__init__.py @@ -0,0 +1,3 @@ + +from .detector import ScaleLSD +from . 
import detector diff --git a/scalelsd/ssl/models/detector.py b/scalelsd/ssl/models/detector.py new file mode 100644 index 0000000000000000000000000000000000000000..32c03c6fcd7130eb6f13c8ffc852e1e849de4116 --- /dev/null +++ b/scalelsd/ssl/models/detector.py @@ -0,0 +1,362 @@ +from collections import defaultdict +import torch +from torch import nn +import torch.nn.functional as F + +import numpy as np +import time + +from ..backbones import build_backbone +from .hafm import HAFMencoder +from .losses import * + +import math +import cv2 +import matplotlib.pyplot as plt + +class ScaleLSD(nn.Module): + def __init__(self, gray_scale=False, use_layer_scale=False, enable_attention_hooks=False): + + super(ScaleLSD, self).__init__() + + num_junctions_inference = 512 + junction_threshold_hm = 0.008 + + self.distance_threshold = 5.0 + + self.hafm_encoder = HAFMencoder(dis_th=self.distance_threshold) + + # self.backbone = build_backbone(gray_scale=gray_scale, use_layer_scale=use_layer_scale) + self.backbone = build_backbone(gray_scale=gray_scale, use_layer_scale=use_layer_scale, enable_attention_hooks=enable_attention_hooks) + + self.j2l_threshold = 10 + + self.num_residuals = 0 + + self.loss = nn.CrossEntropyLoss(reduction='none') + self.bce_loss = nn.BCEWithLogitsLoss(reduction='none') + self.stride = self.backbone.stride + self.train_step = 0 + + @classmethod + def configure(cls, opts): + try: + cls.num_junctions_inference = opts.num_junctions + cls.junction_threshold_hm = opts.junction_hm + except: + pass + + @classmethod + def cli(cls, parser): + try: + parser.add_argument('-nj', '--num-junctions', default=512, type=int, help='number of junctions') + parser.add_argument('-jh', '--junction-hm', default=0.008, type=float, help='junction threshold heatmap') + except: + pass + + def hafm_decoding(self,md_maps, dis_maps, residual_maps, scale=5.0, flatten = True, return_points = False): + + device = md_maps.device + scale = self.distance_threshold + + batch_size, _, height, width 
= md_maps.shape + _y = torch.arange(0,height,device=device).float() + _x = torch.arange(0,width, device=device).float() + + y0, x0 =torch.meshgrid(_y, _x,indexing='ij') + y0 = y0[None,None] + x0 = x0[None,None] + + sign_pad = torch.arange(-self.num_residuals,self.num_residuals+1,device=device,dtype=torch.float32).reshape(1,-1,1,1) + + if residual_maps is not None: + residual = residual_maps*sign_pad + distance_fields = dis_maps + residual + else: + distance_fields = dis_maps + distance_fields = distance_fields.clamp(min=0,max=1.0) + md_un = (md_maps[:,:1] - 0.5)*np.pi*2 + st_un = md_maps[:,1:2]*np.pi/2.0 + ed_un = -md_maps[:,2:3]*np.pi/2.0 + + cs_md = md_un.cos() + ss_md = md_un.sin() + + y_st = torch.tan(st_un) + y_ed = torch.tan(ed_un) + + x_st_rotated = (cs_md - ss_md*y_st)*distance_fields*scale + y_st_rotated = (ss_md + cs_md*y_st)*distance_fields*scale + + x_ed_rotated = (cs_md - ss_md*y_ed)*distance_fields*scale + y_ed_rotated = (ss_md + cs_md*y_ed)*distance_fields*scale + + x_st_final = (x_st_rotated + x0).clamp(min=0,max=width-1) + y_st_final = (y_st_rotated + y0).clamp(min=0,max=height-1) + + x_ed_final = (x_ed_rotated + x0).clamp(min=0,max=width-1) + y_ed_final = (y_ed_rotated + y0).clamp(min=0,max=height-1) + + + lines = torch.stack((x_st_final,y_st_final,x_ed_final,y_ed_final),dim=-1) + if flatten: + lines = lines.reshape(batch_size,-1,4) + if return_points: + points = torch.stack((x0,y0),dim=-1) + points = points.repeat((batch_size,2*self.num_residuals+1,1,1,1)) + if flatten: + points = points.reshape(batch_size,-1,2) + return lines, points + + return lines + + @staticmethod + def non_maximum_suppression(a, kernel_size=3): + ap = F.max_pool2d(a, kernel_size, stride=1, padding=kernel_size//2) + mask = (a == ap).float().clamp(min=0.0) + + return a * mask + + @staticmethod + def get_junctions(jloc, joff, topk = 300, th=0): + height, width = jloc.size(1), jloc.size(2) + jloc = jloc.reshape(-1) + joff = joff.reshape(2, -1) + + + scores, index = 
torch.topk(jloc, k=topk) + # y = (index // width).float() + torch.gather(joff[1], 0, index) + 0.5 + y = torch.div(index,width,rounding_mode='trunc').float()+ torch.gather(joff[1], 0, index) + 0.5 + x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5 + + junctions = torch.stack((x, y)).t() + + if th>0 : + return junctions[scores>th], scores[scores>th] + else: + return junctions, scores + + def wireframe_matcher(self, juncs_pred, lines_pred, hat_points, is_train=False): + cost1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1) + cost2 = torch.sum((lines_pred[:,2:]-juncs_pred[:,None])**2,dim=-1) + + dis1, idx_junc_to_end1 = cost1.min(dim=0) + dis2, idx_junc_to_end2 = cost2.min(dim=0) + length = torch.sum((lines_pred[:,:2]-lines_pred[:,2:])**2,dim=-1) + + idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2) + idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2) + + iskeep = idx_junc_to_end_min < idx_junc_to_end_max ## not the same junction + if self.j2l_threshold>0: + iskeep *= (dis10: + for auxput in auxputs: + loss_dict = self.compute_loss(auxput, targets, mask, loss_dict) + + for key in extra_info.keys(): + extra_info[key] = extra_info[key]/batch_size + + return loss_dict, extra_info + + + @torch.no_grad() + def forward_test(self, images, annotations=None, merge=False): + device = images.device + batch_size, _, height, width = images.shape + + outputs, features, aux = self.forward_backbone(images) + + if "use_lsd" not in annotations.keys(): + annotations["use_lsd"] = True + + # use lsd for theta prediction + if annotations['use_lsd']: + ws = images.shape[3]//self.stride + hs = images.shape[2]//self.stride + lsd = cv2.createLineSegmentDetector(0) + + md_lsd_batch = [] + dis_lsd_batch = [] + for i in range(batch_size): + image = np.array(images[i,0].cpu().numpy()*255,dtype=np.uint8) + lsd_lines = lsd.detect(image)[0].reshape(-1,4) + + # transform lsd lines to lsd-hat-field + md_lsd, dis_lsd, _ = 
self.hafm_encoder.lines2hafm(torch.from_numpy(lsd_lines).to(images.device)/self.stride, hs, ws) + md_lsd_batch.append(md_lsd) + dis_lsd_batch.append(dis_lsd) + + md_pred = torch.stack(md_lsd_batch, dim=0) + dis_pred = torch.stack(dis_lsd_batch, dim=0) + + # for junctions/endpoints extraction + md_pred[:,1:3] = outputs[:,1:3].sigmoid() + # dist + dis_pred = outputs[:,3:4].sigmoid() + + jloc_pred= outputs[:,5:7].softmax(1)[:,1:] + joff_pred= outputs[:,7:9].sigmoid() - 0.5 + else: + + md_pred = outputs[:,:3].sigmoid() + dis_pred = outputs[:,3:4].sigmoid() + res_pred = outputs[:,4:5].sigmoid() + jloc_pred= outputs[:,5:7].softmax(1)[:,1:] + jloc_logits = outputs[:,5:7].softmax(1) + joff_pred= outputs[:,7:9].sigmoid() - 0.5 + + lines_pred_batch, hat_points_batch = self.hafm_decoding(md_pred, dis_pred, None, flatten = True, return_points=True) + + output_list = [] + graph_pred = torch.zeros((batch_size, self.num_junctions_inference, self.num_junctions_inference), device=images.device) + for i in range(batch_size): + if annotations['use_nms']: + jloc_pred_nms = self.non_maximum_suppression(jloc_pred[i]) + else: + jloc_pred_nms = self.non_maximum_suppression(jloc_pred[i], kernel_size=1) + topK = min(self.num_junctions_inference, int((jloc_pred_nms>self.junction_threshold_hm).float().sum().item())) + juncs_pred, juncs_score = self.get_junctions(jloc_pred_nms,joff_pred[i], topk=topK, th=self.junction_threshold_hm) + + lines_adjusted, indices_adj, supports, hat_points, counts = self.wireframe_matcher(juncs_pred, lines_pred_batch[i], hat_points_batch[i]) + + jscales = torch.tensor( + [ + annotations['width']/md_pred.size(3), + annotations['height']/md_pred.size(2) + ], + device=images.device + ) + + junctions = juncs_pred * jscales + supports = [_*self.stride for _ in supports] + + num_junctions = junctions.shape[0] + graph_pred[i, indices_adj[:,0], indices_adj[:,1]] += counts.float() + graph_pred[i, indices_adj[:,1], indices_adj[:,0]] += counts.float() + graph_i = 
graph_pred[i,:num_junctions,:num_junctions] + edges = graph_i.triu().nonzero() + lines = junctions[edges].reshape(-1,4) + scores = graph_pred[i, edges[:,0], edges[:,1]] + + output_list.append( + { + 'lines_pred': lines, + 'lines_score': scores, + 'juncs_pred': junctions, + 'lines_support': supports, + 'juncs_score': juncs_score, + 'graph': graph_i, + 'width': annotations['width'], + 'height': annotations['height'], + } + ) + + return output_list, {} + diff --git a/scalelsd/ssl/models/hafm.py b/scalelsd/ssl/models/hafm.py new file mode 100644 index 0000000000000000000000000000000000000000..89c1d2dd24a713cd07a8fc25896997d7694770b2 --- /dev/null +++ b/scalelsd/ssl/models/hafm.py @@ -0,0 +1,161 @@ +import torch +import numpy as np +from torch.utils.data.dataloader import default_collate + +from scalelsd.base.csrc import _C + +class HAFMencoder(object): + def __init__(self, dis_th = 10, ang_th = 0): + self.dis_th = dis_th + self.ang_th = ang_th + + def __call__(self,annotations): + targets = [] + metas = [] + batch_size = annotations['batch_size'] + stride = annotations['stride'] + for batch_id in range(batch_size): + + junctions = annotations['junctions'][batch_id].clone()[:,[1,0]]/float(stride) + + width = annotations['width']//stride + height = annotations['height']//stride + edge_indices = annotations['line_map'][batch_id].triu().nonzero() + + t, m = self.encoding_single_image(junctions,edge_indices,height,width) + + targets.append(t) + metas.append(m) + + return default_collate(targets),metas + + def adjacent_matrix(self, n, edges, device): + mat = torch.zeros(n+1,n+1,dtype=torch.bool,device=device) + if edges.size(0)>0: + mat[edges[:,0], edges[:,1]] += True + mat[edges[:,1], edges[:,0]] += True + return mat + + def lines2hafm(self, lines, height, width): + device = lines.device + if lines.shape[0] == 0: + hafm_ang = torch.zeros((3,height,width),device=device) + hafm_dis = torch.zeros((1,height,width),device=device) + hafm_mask = 
torch.zeros((1,height,width),device=device) + return torch.zeros((3,height,width),device=device), torch.zeros((1,height,width),device=device), torch.zeros((1,height,width),device=device) + + lmap, _, _ = _C.encodels(lines,height,width,height,width,lines.size(0)) + dismap = torch.sqrt(lmap[0]**2+lmap[1]**2)[None] + def _normalize(inp): + mag = torch.sqrt(inp[0]*inp[0]+inp[1]*inp[1]) + return inp/(mag+1e-6) + md_map = _normalize(lmap[:2]) + st_map = _normalize(lmap[2:4]) + ed_map = _normalize(lmap[4:]) + st_map = lmap[2:4] + ed_map = lmap[4:] + + md_ = md_map.reshape(2,-1).t() + st_ = st_map.reshape(2,-1).t() + ed_ = ed_map.reshape(2,-1).t() + Rt = torch.cat( + (torch.cat((md_[:,None,None,0],md_[:,None,None,1]),dim=2), + torch.cat((-md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1) + R = torch.cat( + (torch.cat((md_[:,None,None,0], -md_[:,None,None,1]),dim=2), + torch.cat((md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1) + #Rtst_ = torch.matmul(Rt, st_[:,:,None]).squeeze(-1).t() + #Rted_ = torch.matmul(Rt, ed_[:,:,None]).squeeze(-1).t() + Rtst_ = torch.bmm(Rt, st_[:,:,None]).squeeze(-1).t() + Rted_ = torch.bmm(Rt, ed_[:,:,None]).squeeze(-1).t() + swap_mask = (Rtst_[1]<0)*(Rted_[1]>0) + pos_ = Rtst_.clone() + neg_ = Rted_.clone() + temp = pos_[:,swap_mask] + pos_[:,swap_mask] = neg_[:,swap_mask] + neg_[:,swap_mask] = temp + + pos_[0] = pos_[0].clamp(min=1e-9) + pos_[1] = pos_[1].clamp(min=1e-9) + neg_[0] = neg_[0].clamp(min=1e-9) + neg_[1] = neg_[1].clamp(max=-1e-9) + + mask = (dismap.view(-1)<=self.dis_th).float() + + pos_map = pos_.reshape(-1,height,width) + neg_map = neg_.reshape(-1,height,width) + + md_angle = torch.atan2(md_map[1], md_map[0]) + pos_angle = torch.atan2(pos_map[1],pos_map[0]) + neg_angle = torch.atan2(neg_map[1],neg_map[0]) + + mask *= (pos_angle.reshape(-1)>self.ang_th*np.pi/2.0) + mask *= (neg_angle.reshape(-1)<-self.ang_th*np.pi/2.0) + + pos_angle_n = pos_angle/(np.pi/2) + neg_angle_n = -neg_angle/(np.pi/2) + md_angle_n = 
md_angle/(np.pi*2) + 0.5 + mask = mask.reshape(height,width) + + + hafm_ang = torch.cat((md_angle_n[None],pos_angle_n[None],neg_angle_n[None],),dim=0) + hafm_dis = dismap.clamp(max=self.dis_th)/self.dis_th + mask = mask[None] + return hafm_ang, hafm_dis, mask + + def encoding_single_image(self, junctions, edge_indices, height, width): + device = junctions.device + + # jmap = torch.zeros((height,width),device=device) + # joff = torch.zeros((2,height,width),device=device,dtype=torch.float32) + jmap = np.zeros((height,width),dtype=np.float32) + joff = np.zeros((2,height,width),dtype=np.float32) + + dx, dy = np.meshgrid(np.arange(width), np.arange(height)) + # gaussian = np.exp(-(dx**2+dy**2)/2.0/2.0**2) + + if junctions.shape[0] > 0: + junctions_np = junctions.cpu().numpy() + xint, yint = junctions_np[:,0].astype(np.int32), junctions_np[:,1].astype(np.int32) + off_x = junctions_np[:,0] - np.floor(junctions_np[:,0]) - 0.5 + off_y = junctions_np[:,1] - np.floor(junctions_np[:,1]) - 0.5 + + jmap[yint,xint] = 1#= jmap[yint,xint] + 1 + joff[0,yint,xint] = off_x + joff[1,yint,xint] = off_y + + lines = junctions[edge_indices].reshape(-1,4) + pos_mat = self.adjacent_matrix(junctions.size(0), edge_indices, device) + labels = torch.ones((lines.shape[0],),device=device) + else: + lines = torch.empty((0,4),device=device) + pos_mat = None + labels = None + # for _x,_y in junctions.cpu().numpy(): + # _map = np.exp(-((dx-_x)**2+(dy-_y)**2)/2.0/8.0**2) + # _map /= _map.max() + # jmap = np.maximum(jmap,_map) + # import matplotlib.pyplot as plt + # import pdb; pdb.set_trace() + jmap = torch.from_numpy(jmap).to(device) + joff = torch.from_numpy(joff).to(device) + hafm_ang, hafm_dis, hafm_mask = self.lines2hafm(lines,height,width) + + + target = { + 'jloc': jmap[None], + 'joff': joff, + 'md': hafm_ang, + 'dis': hafm_dis, + 'mask': hafm_mask + } + + meta = { + 'junc': junctions, + 'lines': lines, + 'Lpos': pos_mat, + 'lpre': lines, + 'lpre_label': labels + } + return target, meta + \ No 
import torch
import torch.nn.functional as F


def cross_entropy_loss_for_junction(logits, positive):
    """Mean two-class cross entropy between ``logits`` and a positive mask.

    Args:
        logits: tensor whose dim 1 holds the two class scores
            (index 0 = negative, index 1 = positive).
        positive: weight of the positive class per element; it must
            broadcast against ``logits`` with dim 1 reduced to size 1
            (presumably a float (N, 1, H, W) mask -- TODO confirm).

    Returns:
        Scalar tensor: the mean loss over all elements.
    """
    nlogp = -F.log_softmax(logits, dim=1)
    loss = positive * nlogp[:, None, 1] + (1 - positive) * nlogp[:, None, 0]
    return loss.mean()


def focal_loss_for_junction(logits, positive, gamma=2.0):
    """Mean two-class focal loss between ``logits`` and a positive mask.

    The per-element cross entropy is built from ``log_softmax`` in the same
    way as ``cross_entropy_loss_for_junction`` instead of calling
    ``F.cross_entropy``: that call rejects a float mask with a singleton
    class dimension, and its (N, ...) output does not line up with the
    (N, 1, ...) modulating factor when batch size is greater than one.

    Args:
        logits: tensor whose dim 1 holds the two class scores
            (index 0 = negative, index 1 = positive).
        positive: positive-class mask, either with a singleton dim 1
            (same layout as for ``cross_entropy_loss_for_junction``) or
            without the class dimension (index-style labels); any dtype
            convertible to ``logits.dtype``.
        gamma: exponent of the modulating factor ``(1 - p_t)``.

    Returns:
        Scalar tensor: the mean focal loss over all elements.
    """
    positive = positive.to(logits.dtype)
    if positive.dim() == logits.dim() - 1:
        # Index-style labels: add the class dimension so everything below
        # broadcasts as (N, 1, ...).
        positive = positive.unsqueeze(1)

    log_prob = F.log_softmax(logits, dim=1)
    prob = log_prob.exp()
    negative = 1 - positive

    ce_loss = -(positive * log_prob[:, 1:2] + negative * log_prob[:, 0:1])
    # p_t is the probability the model assigns to the labelled class.
    p_t = positive * prob[:, 1:2] + negative * prob[:, 0:1]
    loss = ce_loss * ((1 - p_t) ** gamma)
    return loss.mean()


def sigmoid_l1_loss(logits, targets, offset=0.0, mask=None):
    """Mean L1 distance between ``sigmoid(logits) + offset`` and ``targets``.

    Args:
        logits: raw predictions.
        targets: regression targets, broadcastable with ``logits``.
        offset: constant added to the sigmoid output before comparison.
        mask: optional 4-D weight map. Each element's loss is scaled by
            ``mask / w`` where ``w`` is the mean of ``mask`` over dims 2 and
            3; slices whose mean is zero use ``w = 1`` to avoid dividing by
            zero.

    Returns:
        Scalar tensor: the mean (optionally mask-weighted) loss.
    """
    pred = torch.sigmoid(logits) + offset
    loss = torch.abs(pred - targets)

    if mask is not None:
        w = mask.mean(3, True).mean(2, True)
        w[w == 0] = 1
        loss = loss * (mask / w)

    return loss.mean()


def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                   'none': No reduction will be applied to the output.
                   'mean': The output will be averaged.
                   'sum': The output will be summed.
                   Any other value leaves the loss unreduced.

    Returns:
        Loss tensor with the reduction option applied.
    """
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss