| import numpy as np |
| import torch |
| from pytlsd import lsd |
| from sklearn.cluster import DBSCAN |
| import sys |
|
|
| from gluestick.models.base_model import BaseModel |
| from gluestick.models.superpoint import SuperPoint, sample_descriptors |
| from gluestick.geometry import warp_lines_torch |
|
|
| from pathlib import Path |
| import copy, cv2 |
| import os, glob |
| import scalelsd |
| from scalelsd.ssl.models.detector import ScaleLSD |
| from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model |
|
|
|
|
def lines_to_wireframe(lines, line_scores, all_descs, conf):
    """ Given a set of lines, their score and dense descriptors,
        merge close-by endpoints and compute a wireframe defined by
        its junctions and connectivity.

    Endpoints of each image are clustered with DBSCAN (``eps=conf.nms_radius``,
    ``min_samples=1``, so every endpoint belongs to a cluster). Each cluster
    becomes one junction located at the mean of its endpoints, with the mean
    score of the lines it came from.

    Args:
        lines: tensor reshapeable to [b_size, num_lines * 2, 2] endpoints.
        line_scores: [b_size, num_lines] tensor of line scores.
        all_descs: [b_size, dim, H', W'] dense descriptor map.
        conf: object exposing ``nms_radius`` (DBSCAN clustering radius).
    Returns:
        junctions: list of [num_junc, 2] tensors listing all wireframe junctions
        junc_scores: list of [num_junc] tensors with the junction score
        junc_descs: list of [dim, num_junc] tensors with the junction descriptors
        connectivity: list of [num_junc, num_junc] bool arrays with True when 2 junctions are connected
        new_lines: the new set of [b_size, num_lines, 2, 2] lines
        lines_junc_idx: a [b_size, num_lines, 2] tensor with the indices of the junctions of each endpoint
        num_true_junctions: a list of the number of valid junctions for each image in the batch,
                            i.e. before filling with random ones
    Note:
        Every image must contain at least one line; DBSCAN rejects empty input.
    """
    b_size, _, _, _ = all_descs.shape
    device = lines.device
    endpoints = lines.reshape(b_size, -1, 2)

    (junctions, junc_scores, junc_descs, connectivity, new_lines,
     lines_junc_idx, num_true_junctions) = [], [], [], [], [], [], []
    for bs in range(b_size):
        # Cluster the endpoints that are close-by (no noise label with min_samples=1)
        db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit(
            endpoints[bs].detach().cpu().numpy())
        clusters = db.labels_
        n_clusters = len(set(clusters))
        num_true_junctions.append(n_clusters)

        # Average junction position and score per cluster. The output buffers
        # take the dtype of their sources: scatter_reduce_ requires self and
        # src to share a dtype, so a fixed float32 would fail on other inputs.
        clusters = torch.tensor(clusters, dtype=torch.long, device=device)
        new_junc = torch.zeros(n_clusters, 2, dtype=endpoints.dtype,
                               device=device)
        new_junc.scatter_reduce_(0, clusters[:, None].repeat(1, 2),
                                 endpoints[bs], reduce='mean',
                                 include_self=False)
        junctions.append(new_junc)
        line_pt_scores = torch.repeat_interleave(line_scores[bs], 2)
        new_scores = torch.zeros(n_clusters, dtype=line_pt_scores.dtype,
                                 device=device)
        new_scores.scatter_reduce_(0, clusters, line_pt_scores,
                                   reduce='mean', include_self=False)
        junc_scores.append(new_scores)

        # Rebuild the lines from the merged junctions
        pairs = clusters.reshape(-1, 2)  # junction indices of each line
        new_lines.append(new_junc[clusters].reshape(-1, 2, 2))
        lines_junc_idx.append(pairs)

        # Two junctions are connected when a line joins them (plus self-loops)
        junc_connect = torch.eye(n_clusters, dtype=torch.bool, device=device)
        junc_connect[pairs[:, 0], pairs[:, 1]] = True
        junc_connect[pairs[:, 1], pairs[:, 0]] = True
        connectivity.append(junc_connect)

        # Interpolate the junction descriptors from the dense map
        junc_descs.append(sample_descriptors(
            new_junc[None], all_descs[bs:(bs + 1)], 8)[0])

    new_lines = torch.stack(new_lines, dim=0)
    lines_junc_idx = torch.stack(lines_junc_idx, dim=0)
    return (junctions, junc_scores, junc_descs, connectivity,
            new_lines, lines_junc_idx, num_true_junctions)
|
|
|
|
class SPWireframeDescriptor(BaseModel):
    """SuperPoint keypoints combined with line segments into a wireframe.

    Lines are taken from ``data['lines']`` / ``data['line_scores']`` when
    present; otherwise they are detected with a ScaleLSD checkpoint or with
    pytlsd, depending on ``self.extr_conf`` (set through ``update_conf``).
    Line endpoints (optionally merged into junctions) are placed in front of
    the SuperPoint keypoints in the returned point set.
    """
    default_conf = {
        'sp_params': {
            'has_detector': True,
            'has_descriptor': True,
            'descriptor_dim': 256,
            'trainable': False,
            'return_all': True,
            'sparse_outputs': True,
            'nms_radius': 4,
            'detection_threshold': 0.005,
            'max_num_keypoints': 1000,
            'force_num_keypoints': True,
            'remove_borders': 4,
        },
        'wireframe_params': {
            'merge_points': True,
            'merge_line_endpoints': True,
            'nms_radius': 3,
            'max_n_junctions': 500,
        },
        'max_n_lines': 250,
        'min_length': 15,
    }
    required_data_keys = ['image']

    def _init(self, conf):
        self.conf = conf
        self.sp = SuperPoint(conf.sp_params)
        # Line-extractor settings; empty (or None) selects the default ScaleLSD model.
        self.extr_conf = {}
        # Loaded ScaleLSD models keyed by (checkpoint, device). Kept in a plain
        # dict so that (presumably BaseModel is a torch.nn.Module) they are not
        # registered as submodules of this model.
        self._scalelsd_models = {}

    def detect_lsd_lines(self, x, max_n_lines=None):
        """Detect line segments with pytlsd on every image of the batch.

        Args:
            x: batch of images with values in [0, 1] (scaled by 255 to uint8).
            max_n_lines: number of lines to keep per image; defaults to
                ``self.conf.max_n_lines``.
        Returns:
            lines: [b, n, 2, 2] endpoints, cast to the dtype/device of ``x``.
            scores: [b, n] line scores (LSD score * sqrt(length)).
            valid_lines: [b, n] boolean tensor, all True.
        Note:
            ``torch.stack`` requires the same number of lines in every image.
        """
        if max_n_lines is None:
            max_n_lines = self.conf.max_n_lines
        lines, scores, valid_lines = [], [], []
        for b in range(len(x)):
            img = (x[b].squeeze().cpu().numpy() * 255).astype(np.uint8)
            if max_n_lines is None:
                b_segs = lsd(img)
            else:
                # Increase the LSD scale until enough segments are found
                for s in [0.3, 0.4, 0.5, 0.7, 0.8, 1.0]:
                    b_segs = lsd(img, scale=s)
                    if len(b_segs) >= max_n_lines:
                        break

            segs_length = np.linalg.norm(b_segs[:, 2:4] - b_segs[:, 0:2], axis=1)
            # Drop short segments
            keep = segs_length >= self.conf.min_length
            b_segs = b_segs[keep]
            segs_length = segs_length[keep]
            b_scores = b_segs[:, -1] * np.sqrt(segs_length)
            # Keep the highest-scoring segments
            indices = np.argsort(-b_scores)
            if max_n_lines is not None:
                indices = indices[:max_n_lines]
            lines.append(torch.from_numpy(b_segs[indices, :4].reshape(-1, 2, 2)))
            scores.append(torch.from_numpy(b_scores[indices]))
            valid_lines.append(torch.ones_like(scores[-1], dtype=torch.bool))

        lines = torch.stack(lines).to(x)
        scores = torch.stack(scores).to(x)
        valid_lines = torch.stack(valid_lines).to(x.device)
        return lines, scores, valid_lines

    def update_conf(self, conf):
        """Set the line-extractor settings used by ``_forward``.

        ``None`` or an empty dict selects the default ScaleLSD checkpoint; a
        dict with ``model_name == 'lsd'`` selects pytlsd; any other
        ``model_name`` is loaded from ``models/<model_name>``.
        """
        self.extr_conf = conf

    def _detect_scalelsd_lines(self, image, ckpt, junction_threshold_hm,
                               num_junctions_inference, size, use_lsd,
                               use_nms, threshold):
        """Run a ScaleLSD checkpoint on channel 0 of the first batch image.

        Args:
            image: batched image tensor.
            ckpt: checkpoint path passed to ``load_scalelsd_model``.
            junction_threshold_hm, num_junctions_inference: assigned to the
                model attributes of the same name before inference.
            size: (width, height) the image is resized to before inference.
            use_lsd, use_nms: forwarded through the ``meta`` dict.
            threshold: lines scoring below this value are discarded.
        Returns:
            lines: [1, num_lines, 2, 2] endpoints.
            line_scores: [1, num_lines] scores.
        """
        device = image.device
        key = (ckpt, str(device))
        model = self._scalelsd_models.get(key)
        if model is None:
            # Loading a checkpoint is expensive: do it once per (ckpt, device)
            # instead of on every forward pass.
            model = load_scalelsd_model(ckpt, device)
            self._scalelsd_models[key] = model
        model.junction_threshold_hm = junction_threshold_hm
        model.num_junctions_inference = num_junctions_inference

        height, width = image.shape[-2:]
        image_np = np.ascontiguousarray(image[0, 0].detach().cpu().numpy())
        resized = torch.from_numpy(cv2.resize(image_np, size)).float()
        meta = {
            'width': width,
            'height': height,
            'filename': '',
            'use_lsd': use_lsd,
            'use_nms': use_nms,
        }
        outputs, _ = model(resized[None, None].to(device), meta)
        lines = outputs[0]['lines_pred']
        line_scores = outputs[0]['lines_score']
        keep = line_scores >= threshold
        # Use the [b_size=1, num_lines, 2, 2] layout of detect_lsd_lines so the
        # rest of _forward (len(lines[0]), lines.shape[1], endpoints_pooling)
        # works. Assumes rows are (x1, y1, x2, y2), the endpoint order
        # _forward already relied on when reshaping to (b, -1, 2).
        lines = lines[keep].reshape(1, -1, 2, 2)
        line_scores = line_scores[keep][None]
        return lines, line_scores

    def _detect_lines(self, data):
        """Detect lines for ``data['image']`` according to ``self.extr_conf``.

        Returns:
            lines: [b, n, 2, 2] endpoints and line_scores: [b, n] scores.
        """
        extr_conf = self.extr_conf
        if not extr_conf:
            # No extractor configuration (None or the {} set in _init):
            # default ScaleLSD checkpoint and settings.
            return self._detect_scalelsd_lines(
                data['image'], 'models/scalelsd-vitbase-v1-train-sa1b.pt',
                junction_threshold_hm=0.008, num_junctions_inference=4096,
                size=(512, 512), use_lsd=False, use_nms=False, threshold=5)
        if extr_conf['model_name'] != 'lsd':
            return self._detect_scalelsd_lines(
                data['image'], "models/" + extr_conf['model_name'],
                junction_threshold_hm=extr_conf['junction_threshold_hm'],
                num_junctions_inference=extr_conf['num_junctions_inference'],
                size=(extr_conf['width'], extr_conf['height']),
                use_lsd=extr_conf['use_lsd'], use_nms=extr_conf['use_nms'],
                threshold=extr_conf['threshold'])
        if 'original_img' in data:
            # Detect on the original image, then warp the lines with data['H']
            lines, line_scores, valid_lines = self.detect_lsd_lines(
                data['original_img'], self.conf.max_n_lines * 3)
            lines, valid_lines2 = warp_lines_torch(
                lines, data['H'], False, data['image'].shape[-2:])
            valid_lines = valid_lines & valid_lines2
            lines[~valid_lines] = -1
            line_scores[~valid_lines] = 0
            # Keep the max_n_lines best (invalid lines were given score 0)
            sorted_scores, sorting_indices = torch.sort(
                line_scores, dim=-1, descending=True)
            line_scores = sorted_scores[:, :self.conf.max_n_lines]
            sorting_indices = sorting_indices[:, :self.conf.max_n_lines]
            lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1)
            return lines, line_scores
        lines, line_scores, _ = self.detect_lsd_lines(
            data['image'], max_n_lines=1000000)
        return lines, line_scores

    def _forward(self, data):
        """Build the joint point/line representation for ``data['image']``.

        Returns a dict with ``keypoints`` (line junctions/endpoints first, then
        SuperPoint keypoints), ``keypoint_scores``, ``descriptors``,
        ``pl_associativity``, ``num_junctions``, ``lines``, ``orig_lines``,
        ``lines_junc_idx`` and ``line_scores``.
        """
        b_size, _, h, w = data['image'].shape
        device = data['image'].device

        if not self.conf.sp_params.force_num_keypoints:
            assert b_size == 1, "Only batch size of 1 accepted for non padded inputs"

        # Line detection, unless lines were supplied by the caller
        if 'lines' not in data or 'line_scores' not in data:
            lines, line_scores = self._detect_lines(data)
        else:
            lines, line_scores = data['lines'], data['line_scores']
        if line_scores.shape[-1] != 0:
            # Normalize per image; out of place so data['line_scores'] is untouched
            line_scores = line_scores / (
                line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None])

        # SuperPoint prediction
        pred = self.sp(data)

        # Remove keypoints that are too close to line endpoints
        if self.conf.wireframe_params.merge_points:
            kp = pred['keypoints']
            line_endpts = lines.reshape(b_size, -1, 2)
            dist_pt_lines = torch.norm(
                kp[:, :, None] - line_endpts[:, None], dim=-1)
            pts_to_remove = torch.any(
                dist_pt_lines < self.conf.sp_params.nms_radius, dim=2)
            assert len(kp) == 1
            keep = ~pts_to_remove[0]
            pred['keypoints'] = pred['keypoints'][0][keep][None]
            pred['keypoint_scores'] = pred['keypoint_scores'][0][keep][None]
            pred['descriptors'] = pred['descriptors'][0][:, keep][None]

        orig_lines = lines.clone()
        if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0:
            # Merge close-by endpoints into junctions
            (line_points, line_pts_scores, line_descs, line_association,
             lines, lines_junc_idx, num_true_junctions) = lines_to_wireframe(
                lines, line_scores, pred['all_descriptors'],
                conf=self.conf.wireframe_params)

            # Junctions first, then keypoints
            all_points, all_scores, all_descs, pl_associativity = [], [], [], []
            for bs in range(b_size):
                all_points.append(torch.cat(
                    [line_points[bs], pred['keypoints'][bs]], dim=0))
                all_scores.append(torch.cat(
                    [line_pts_scores[bs], pred['keypoint_scores'][bs]], dim=0))
                all_descs.append(torch.cat(
                    [line_descs[bs], pred['descriptors'][bs]], dim=1))

                n_junc = num_true_junctions[bs]
                associativity = torch.eye(len(all_points[-1]), dtype=torch.bool,
                                          device=device)
                associativity[:n_junc, :n_junc] = line_association[bs][:n_junc, :n_junc]
                pl_associativity.append(associativity)

            all_points = torch.stack(all_points, dim=0)
            all_scores = torch.stack(all_scores, dim=0)
            all_descs = torch.stack(all_descs, dim=0)
            pl_associativity = torch.stack(pl_associativity, dim=0)
        else:
            # Every endpoint is its own junction
            all_points = torch.cat([lines.reshape(b_size, -1, 2),
                                    pred['keypoints']], dim=1)
            n_pts = all_points.shape[1]
            num_lines = lines.shape[1]
            num_true_junctions = [num_lines * 2] * b_size
            all_scores = torch.cat([
                torch.repeat_interleave(line_scores, 2, dim=1),
                pred['keypoint_scores']], dim=1)
            pred['line_descriptors'] = self.endpoints_pooling(
                lines, pred['all_descriptors'], (h, w))
            all_descs = torch.cat([
                pred['line_descriptors'].reshape(
                    b_size, self.conf.sp_params.descriptor_dim, -1),
                pred['descriptors']], dim=2)
            pl_associativity = torch.eye(
                n_pts, dtype=torch.bool,
                device=device)[None].repeat(b_size, 1, 1)
            lines_junc_idx = torch.arange(
                num_lines * 2, device=device).reshape(1, -1, 2).repeat(b_size, 1, 1)

        del pred['all_descriptors']  # free memory
        torch.cuda.empty_cache()

        return {'keypoints': all_points,
                'keypoint_scores': all_scores,
                'descriptors': all_descs,
                'pl_associativity': pl_associativity,
                'num_junctions': torch.tensor(num_true_junctions),
                'lines': lines,
                'orig_lines': orig_lines,
                'lines_junc_idx': lines_junc_idx,
                'line_scores': line_scores,
                }

    @staticmethod
    def endpoints_pooling(segs, all_descriptors, img_shape):
        """Sample the dense descriptors at the (rounded) segment endpoints.

        Args:
            segs: [b, n, 2, 2] endpoints in image coordinates (x, y).
            all_descriptors: [b, dim, H', W'] dense descriptor map.
            img_shape: (h, w) of the image the endpoints refer to.
        """
        assert segs.ndim == 4 and segs.shape[-2:] == (2, 2)
        filter_shape = all_descriptors.shape[-2:]
        scale_x = filter_shape[1] / img_shape[1]
        scale_y = filter_shape[0] / img_shape[0]

        scaled_segs = torch.round(segs * torch.tensor([scale_x, scale_y]).to(segs)).long()
        scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1)
        scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1)
        line_descriptors = [all_descriptors[None, b, ..., torch.squeeze(b_segs[..., 1]), torch.squeeze(b_segs[..., 0])]
                            for b, b_segs in enumerate(scaled_segs)]
        line_descriptors = torch.cat(line_descriptors)
        return line_descriptors

    def loss(self, pred, data):
        raise NotImplementedError

    def metrics(self, pred, data):
        return {}
|
|