| from typing import Any
|
| import numpy as np
|
| import torch
|
| import lightning as L
|
| from lightglue import SuperPoint, DISK, SIFT, ALIKED
|
| import time
|
|
|
| from utils import rotation_angular_error, translation_angular_error, error_auc
|
| from .relpose import RelPose
|
|
|
|
|
# Registry mapping a feature-name string (the `features` hyperparameter)
# to its LightGlue extractor class.
keypoint_dict = dict(
    superpoint=SuperPoint,
    disk=DISK,
    sift=SIFT,
    aliked=ALIKED,
)
|
|
|
|
|
class PL_RelPose(L.LightningModule):
    """Lightning wrapper around :class:`RelPose` for relative pose regression.

    A frozen keypoint extractor (SuperPoint / DISK / SIFT / ALIKED) produces
    features for each image of a pair; ``RelPose`` then regresses the relative
    rotation and translation between the two views. Per-step errors are
    buffered per split and summarized at the end of each epoch.
    """

    def __init__(
        self,
        task,
        lr,
        epochs,
        pct_start,
        num_keypoints,
        n_layers,
        num_heads,
        features='superpoint',
    ):
        """Build the extractor, pose regressor, and loss bookkeeping.

        Args:
            task: ``'scene'`` or ``'object'``; selects which inputs are fed
                to the regressor (``'object'`` additionally uses bboxes).
            lr: peak learning rate for AdamW / OneCycleLR.
            epochs: total epochs, used as the OneCycle schedule length.
            pct_start: fraction of the cycle spent increasing the LR.
            num_keypoints: maximum keypoints the extractor may return.
            n_layers: number of layers in the RelPose module.
            num_heads: number of attention heads in the RelPose module.
            features: extractor name; must be a key of ``keypoint_dict``.

        Raises:
            ValueError: if ``task`` is not ``'scene'`` or ``'object'``.
            KeyError: if ``features`` is not a known extractor name.
        """
        super().__init__()

        # Fail fast: the original code left pred_r/pred_t unbound at forward
        # time for an unknown task (UnboundLocalError deep in training).
        if task not in ('scene', 'object'):
            raise ValueError(f"unknown task {task!r}; expected 'scene' or 'object'")

        # Frozen feature extractor — always eval mode, never optimized.
        self.extractor = keypoint_dict[features](
            max_num_keypoints=num_keypoints, detection_threshold=0.0
        ).eval()
        self.module = RelPose(
            features=features, task=task, n_layers=n_layers, num_heads=num_heads
        )
        self.criterion = torch.nn.HuberLoss()

        # NOTE(review): these look like uncertainty-weighting log-scales, but
        # they are never used in the loss below and are excluded from the
        # optimizer (which only sees ``self.module.parameters()``). Kept so
        # existing checkpoints still load — confirm whether they can go.
        self.s_r = torch.nn.Parameter(torch.zeros(1))
        self.s_t = torch.nn.Parameter(torch.zeros(1))

        # Per-split error buffers, flushed in ``_shared_on_epoch_end``.
        self.r_errors = {k: [] for k in ['train', 'valid', 'test']}
        self.ta_errors = {k: [] for k in ['train', 'valid', 'test']}
        self.t_errors = {k: [] for k in ['train', 'valid', 'test']}

        self.save_hyperparameters()

    def _shared_log(self, mode, loss, loss_r, loss_t, loss_ta, loss_tn):
        """Log the summed loss and its components under ``{mode}_loss/*``."""
        self.log_dict({
            f'{mode}_loss/sum': loss,
            f'{mode}_loss/r': loss_r,
            f'{mode}_loss/t': loss_t,
            f'{mode}_loss/ta': loss_ta,
            f'{mode}_loss/tn': loss_tn,
        }, on_epoch=True, sync_dist=True)

    def _shared_step(self, batch, batch_idx, mode):
        """Common train/valid/test bookkeeping: forward, buffer errors, log."""
        loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err = \
            self._shared_forward_step(batch, batch_idx)

        self.r_errors[mode].append(r_err)
        self.ta_errors[mode].append(ta_err)
        self.t_errors[mode].append(t_err)

        self._shared_log(mode, loss, loss_r, loss_t, loss_ta, loss_tn)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, batch_idx, 'valid')

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, batch_idx, 'test')

    def _extract(self, image0, image1):
        """Run the frozen extractor on both views; no gradients flow here."""
        with torch.no_grad():
            feats0 = self.extractor({'image': image0})
            feats1 = self.extractor({'image': image1})
        return feats0, feats1

    @staticmethod
    def _apply_scales(feats0, feats1, scales):
        """Rescale keypoints in place (maps them back to original image
        resolution after dataset-side resizing — presumably; confirm loader)."""
        feats0['keypoints'] *= scales[:, 0].unsqueeze(1)
        feats1['keypoints'] *= scales[:, 1].unsqueeze(1)

    def _regress_pose(self, feats0, feats1, intrinsics, bboxes=None):
        """Dispatch to RelPose with task-specific inputs.

        Returns:
            ``(pred_r, pred_t)`` as produced by ``self.module``.

        Raises:
            ValueError: if the configured task is unknown (defensive; the
                constructor already validates it).
        """
        if self.hparams.task == 'scene':
            return self.module({
                'image0': {**feats0, 'intrinsics': intrinsics[:, 0]},
                'image1': {**feats1, 'intrinsics': intrinsics[:, 1]},
            })
        if self.hparams.task == 'object':
            return self.module({
                'image0': {**feats0, 'intrinsics': intrinsics[:, 0], 'bbox': bboxes[:, 0]},
                'image1': {**feats1, 'intrinsics': intrinsics[:, 1]},
            })
        raise ValueError(f"unknown task {self.hparams.task!r}")

    def _shared_forward_step(self, batch, batch_idx):
        """Forward one batch and compute losses and detached error metrics.

        Expects ``batch`` with keys ``images`` (B, 2, ...), ``rotation``,
        ``translation``, ``intrinsics`` (B, 2, 3, 3) and optionally
        ``scales`` / ``bboxes`` — assumed layout, confirm against the loader.

        Returns:
            ``(loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err)``
        """
        images = batch['images']
        rotation = batch['rotation']
        translation = batch['translation']
        intrinsics = batch['intrinsics']

        feats0, feats1 = self._extract(images[:, 0, ...], images[:, 1, ...])

        if 'scales' in batch:
            self._apply_scales(feats0, feats1, batch['scales'])

        pred_r, pred_t = self._regress_pose(
            feats0, feats1, intrinsics,
            bboxes=batch['bboxes'] if 'bboxes' in batch else None,
        )

        r_err = rotation_angular_error(pred_r, rotation)
        ta_err = translation_angular_error(pred_t, translation)

        # Drive each angular error toward zero with a Huber loss.
        loss_r = self.criterion(r_err, torch.zeros_like(r_err))
        loss_ta = self.criterion(ta_err, torch.zeros_like(ta_err))
        # Direction loss on unit vectors plus a raw-magnitude loss.
        loss_tn = self.criterion(
            pred_t / pred_t.norm(2, dim=-1, keepdim=True),
            translation / translation.norm(2, dim=-1, keepdim=True),
        )
        loss_t = self.criterion(pred_t, translation)

        loss = loss_r + loss_ta + loss_t + loss_tn

        # Detach metrics so the epoch-end buffers don't retain graphs.
        r_err = r_err.detach()
        ta_err = ta_err.detach()
        t_err = (pred_t.detach() - translation).norm(2, dim=1)

        return loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err

    def predict_one_data(self, data, device='cuda'):
        """Run a single sample end-to-end and time each stage.

        Args:
            data: dict shaped like a training batch (batch dimension present).
            device: device to move tensors to before inference.

        Returns:
            ``(pred_r[0], pred_t[0], preprocess_s, extract_s, regress_s)`` —
            the first pose in the batch plus wall-clock seconds per stage.
        """
        st_time = time.time()
        images = data['images'].to(device)
        intrinsics = data['intrinsics'].to(device)

        image0 = images[:, 0, ...]
        image1 = images[:, 1, ...]

        preprocess = time.time()

        feats0, feats1 = self._extract(image0, image1)

        extract_time = time.time()

        if 'scales' in data:
            self._apply_scales(feats0, feats1, data['scales'].to(device))

        pred_r, pred_t = self._regress_pose(
            feats0, feats1, intrinsics,
            bboxes=data['bboxes'].to(device) if 'bboxes' in data else None,
        )

        regress_time = time.time()

        return (
            pred_r[0],
            pred_t[0],
            preprocess - st_time,
            extract_time - preprocess,
            regress_time - extract_time,
        )

    def _shared_on_epoch_end(self, mode):
        """Aggregate buffered errors for ``mode``, log summaries, clear buffers."""
        r_errors = torch.hstack(self.r_errors[mode]).rad2deg()
        ta_errors = torch.hstack(self.ta_errors[mode]).rad2deg()
        # Fold angles into [0, 90] — presumably because translation direction
        # is sign-ambiguous; confirm against the error definition in utils.
        ta_errors = torch.minimum(ta_errors, 180 - ta_errors)

        # Pose AUC over the max of rotation/translation angular error.
        auc = error_auc(torch.maximum(r_errors, ta_errors).cpu(), [5, 10, 20], mode)
        t_errors = torch.hstack(self.t_errors[mode])

        self.log_dict({
            **auc,
            f'{mode}_Rot./Avg. Error': r_errors.mean(),
            f'{mode}_Rot./Med. Error': r_errors.median(),
            f'{mode}_Rot./@30° ACC': (r_errors < 30).float().mean(),
            f'{mode}_Rot./@15° ACC': (r_errors < 15).float().mean(),
            f'{mode}_Trans./Avg. Error': t_errors.mean(),
            f'{mode}_Trans./Med. Error': t_errors.median(),
            f'{mode}_Trans./@10cm ACC': (t_errors < 0.1).float().mean(),
            f'{mode}_Trans./@1m ACC': (t_errors < 1.0).float().mean(),
        }, sync_dist=True)

        self.r_errors[mode].clear()
        self.ta_errors[mode].clear()
        self.t_errors[mode].clear()

    def on_train_epoch_end(self):
        self._shared_on_epoch_end('train')

    def on_validation_epoch_end(self):
        self._shared_on_epoch_end('valid')

    def on_test_epoch_end(self):
        self._shared_on_epoch_end('test')

    def configure_optimizers(self):
        """AdamW on the RelPose parameters with a one-cycle LR schedule.

        Only ``self.module``'s parameters are optimized; the frozen extractor
        and the unused ``s_r``/``s_t`` parameters are not. ``steps_per_epoch=1``
        matches Lightning's default of stepping the scheduler once per epoch.
        """
        optimizer = torch.optim.AdamW(self.module.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams.lr,
            steps_per_epoch=1,
            epochs=self.hparams.epochs,
            pct_start=self.hparams.pct_start,
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
        }
|
|
|