# SRPose/model/pl_trainer.py
# Hub page header (not part of the program): uploaded by FrickYinn,
# commit e170a8e (verified), commit message "Upload 53 files".
from typing import Any
import numpy as np
import torch
import lightning as L
from lightglue import SuperPoint, DISK, SIFT, ALIKED
import time
from utils import rotation_angular_error, translation_angular_error, error_auc
from .relpose import RelPose
# Lookup table from the ``features`` option name to the lightglue keypoint
# extractor class that implements it.
keypoint_dict = dict(
    superpoint=SuperPoint,
    disk=DISK,
    sift=SIFT,
    aliked=ALIKED,
)
class PL_RelPose(L.LightningModule):
    """Lightning module that trains a ``RelPose`` regressor on keypoint features.

    A keypoint extractor (selected by ``features`` from ``keypoint_dict``) is
    always run under ``torch.no_grad()``; only ``self.module`` (the ``RelPose``
    network) is handed to the optimizer. Each step predicts a rotation and a
    translation for an image pair and is scored against ``batch['rotation']``
    and ``batch['translation']``.

    Args:
        task: ``'scene'`` or ``'object'``. The object task additionally feeds
            ``bboxes[:, 0]`` to the regressor as the first view's ``'bbox'``.
        lr: Learning rate for AdamW; also the peak LR of the one-cycle schedule.
        epochs: Number of epochs the one-cycle schedule spans.
        pct_start: ``pct_start`` argument forwarded to ``OneCycleLR``.
        num_keypoints: ``max_num_keypoints`` passed to the extractor.
        n_layers: Forwarded to ``RelPose``.
        num_heads: Forwarded to ``RelPose``.
        features: Key into ``keypoint_dict`` choosing the extractor.
    """

    def __init__(
        self,
        task,
        lr,
        epochs,
        pct_start,
        num_keypoints,
        n_layers,
        num_heads,
        features='superpoint',
    ):
        super().__init__()

        self.extractor = keypoint_dict[features](max_num_keypoints=num_keypoints, detection_threshold=0.0).eval()
        self.module = RelPose(features=features, task=task, n_layers=n_layers, num_heads=num_heads)
        self.criterion = torch.nn.HuberLoss()

        # Weights for a learned loss weighting that is currently disabled (see
        # the commented-out loss in `_shared_forward_step`). They are kept so
        # the module's state_dict layout does not change.
        self.s_r = torch.nn.Parameter(torch.zeros(1))
        # self.s_ta = torch.nn.Parameter(torch. zeros(1))
        self.s_t = torch.nn.Parameter(torch.zeros(1))

        # Per-split buffers of per-sample errors, filled by the *_step hooks
        # and flushed in `_shared_on_epoch_end`.
        self.r_errors = {k: [] for k in ['train', 'valid', 'test']}
        self.ta_errors = {k: [] for k in ['train', 'valid', 'test']}
        self.t_errors = {k: [] for k in ['train', 'valid', 'test']}

        self.save_hyperparameters()

    def train(self, mode=True):
        """Switch train/eval mode while keeping the extractor in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would undo
        the ``.eval()`` applied to the extractor in ``__init__``. The extractor
        is never optimized here, so it is pinned to eval mode.
        """
        super().train(mode)
        self.extractor.eval()
        return self

    def _shared_log(self, mode, loss, loss_r, loss_t, loss_ta, loss_tn):
        """Log the total loss and its four components under ``{mode}_loss/*``."""
        self.log_dict({
            f'{mode}_loss/sum': loss,
            f'{mode}_loss/r': loss_r,
            f'{mode}_loss/t': loss_t,
            f'{mode}_loss/ta': loss_ta,
            f'{mode}_loss/tn': loss_tn,
        }, on_epoch=True, sync_dist=True)

    def _shared_step(self, mode, batch, batch_idx):
        """Run one forward step, buffer its errors for ``mode``, log and return the loss."""
        loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err = self._shared_forward_step(batch, batch_idx)

        self.r_errors[mode].append(r_err)
        self.ta_errors[mode].append(ta_err)
        self.t_errors[mode].append(t_err)

        self._shared_log(mode, loss, loss_r, loss_t, loss_ta, loss_tn)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        return self._shared_step('train', batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """Compute and log validation losses for one batch (returns ``None``)."""
        self._shared_step('valid', batch, batch_idx)

    def test_step(self, batch, batch_idx):
        """Compute and log test losses for one batch (returns ``None``)."""
        self._shared_step('test', batch, batch_idx)

    def _extract_features(self, image0, image1):
        """Run the keypoint extractor on both views without tracking gradients."""
        with torch.no_grad():
            feats0 = self.extractor({'image': image0})
            feats1 = self.extractor({'image': image1})
        return feats0, feats1

    @staticmethod
    def _rescale_keypoints(feats0, feats1, scales):
        """Multiply each view's keypoints in place by its entry of ``scales``."""
        # NOTE(review): presumably scales[:, i] is a per-sample (x, y) factor
        # mapping keypoints back to original image coordinates; confirm
        # against the dataset code.
        feats0['keypoints'] *= scales[:, 0].unsqueeze(1)
        feats1['keypoints'] *= scales[:, 1].unsqueeze(1)

    def _regress_pose(self, feats0, feats1, intrinsics, bboxes=None):
        """Assemble the two-view input dict for the configured task and run ``self.module``.

        Raises:
            ValueError: If ``self.hparams.task`` is neither ``'scene'`` nor
                ``'object'``.
        """
        task = self.hparams.task
        view0 = {**feats0, 'intrinsics': intrinsics[:, 0]}
        view1 = {**feats1, 'intrinsics': intrinsics[:, 1]}

        if task == 'object':
            # Only the first view carries a bounding box.
            view0['bbox'] = bboxes[:, 0]
        elif task != 'scene':
            raise ValueError(f"Unknown task {task!r}; expected 'scene' or 'object'.")

        return self.module({'image0': view0, 'image1': view1})

    def _shared_forward_step(self, batch, batch_idx):
        """Predict the pose for a batch and compute losses and error metrics.

        Reads ``images``, ``rotation``, ``translation`` and ``intrinsics`` from
        ``batch``, plus ``scales`` when present and ``bboxes`` for the object
        task.

        Returns:
            Tuple ``(loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err,
            t_err)``. The first five are Huber losses (``loss`` is their plain
            sum); the last three are detached per-sample errors.

        Raises:
            ValueError: If the configured task is not supported.
        """
        images = batch['images']
        rotation = batch['rotation']
        translation = batch['translation']
        intrinsics = batch['intrinsics']

        # Axis 1 of `images` indexes the two views of each pair.
        feats0, feats1 = self._extract_features(images[:, 0, ...], images[:, 1, ...])
        if 'scales' in batch:
            self._rescale_keypoints(feats0, feats1, batch['scales'])

        bboxes = batch['bboxes'] if self.hparams.task == 'object' else None
        pred_r, pred_t = self._regress_pose(feats0, feats1, intrinsics, bboxes)

        r_err = rotation_angular_error(pred_r, rotation)
        ta_err = translation_angular_error(pred_t, translation)

        loss_r = self.criterion(r_err, torch.zeros_like(r_err))
        loss_ta = self.criterion(ta_err, torch.zeros_like(ta_err))
        # Compare unit translation directions. `normalize` clamps the norm by
        # a small epsilon, so a zero-length vector yields zeros instead of the
        # NaN that a plain division by the norm would produce.
        loss_tn = self.criterion(
            torch.nn.functional.normalize(pred_t, p=2, dim=-1),
            torch.nn.functional.normalize(translation, p=2, dim=-1),
        )
        loss_t = self.criterion(pred_t, translation)

        # Disabled alternative using the learned weights s_r / s_t:
        # loss = loss_r * torch.exp(-self.s_r) + loss_t * torch.exp(-self.s_t) + loss_ta * torch.exp(-self.s_ta) + self.s_r + self.s_t + self.s_ta
        loss = loss_r + loss_ta + loss_t + loss_tn

        # Metrics are buffered across the epoch, so drop their autograd graphs.
        r_err = r_err.detach()
        ta_err = ta_err.detach()
        t_err = (pred_t.detach() - translation).norm(2, dim=1)

        return loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err

    def predict_one_data(self, data, device='cuda'):
        """Predict the pose for one sample and report wall-clock stage timings.

        Args:
            data: Dict with ``images`` and ``intrinsics`` tensors, optionally
                ``scales``, and ``bboxes`` for the object task. Tensors are
                moved to ``device``.
            device: Device the inputs are moved to; the module itself is not
                moved by this method.

        Returns:
            Tuple ``(pred_r[0], pred_t[0], preprocess_s, extract_s,
            regress_s)``: the prediction for the first batch element followed
            by the seconds spent on input transfer, keypoint extraction and
            pose regression.

        Raises:
            ValueError: If the configured task is not supported.
        """
        # NOTE(review): on CUDA, kernels launch asynchronously; without an
        # explicit synchronization these wall-clock splits may attribute time
        # to a later stage than the one that actually spent it.
        st_time = time.time()
        images = data['images'].to(device)
        intrinsics = data['intrinsics'].to(device)
        image0 = images[:, 0, ...]
        image1 = images[:, 1, ...]
        preprocess = time.time()

        # Pure inference: keep the whole path out of autograd, not only the
        # extractor.
        with torch.no_grad():
            feats0, feats1 = self._extract_features(image0, image1)
            extract_time = time.time()

            if 'scales' in data:
                self._rescale_keypoints(feats0, feats1, data['scales'].to(device))

            bboxes = data['bboxes'].to(device) if self.hparams.task == 'object' else None
            pred_r, pred_t = self._regress_pose(feats0, feats1, intrinsics, bboxes)
        regress_time = time.time()

        return pred_r[0], pred_t[0], preprocess - st_time, extract_time - preprocess, regress_time - extract_time

    def _shared_on_epoch_end(self, mode):
        """Aggregate the errors buffered for ``mode``, log epoch metrics, and reset the buffers."""
        # The buffered angular errors are converted from radians to degrees.
        r_errors = torch.hstack(self.r_errors[mode]).rad2deg()
        ta_errors = torch.hstack(self.ta_errors[mode]).rad2deg()
        # Fold angles above 90 degrees back, so a direction and its negation
        # score the same.
        ta_errors = torch.minimum(ta_errors, 180 - ta_errors)
        # The AUC is computed on the worse of the two angular errors per sample.
        auc = error_auc(torch.maximum(r_errors, ta_errors).cpu(), [5, 10, 20], mode)

        t_errors = torch.hstack(self.t_errors[mode])

        self.log_dict({
            **auc,
            f'{mode}_Rot./Avg. Error': r_errors.mean(),
            f'{mode}_Rot./Med. Error': r_errors.median(),
            f'{mode}_Rot./@30° ACC': (r_errors < 30).float().mean(),
            f'{mode}_Rot./@15° ACC': (r_errors < 15).float().mean(),
            # f'{mode}_ta/avg': ta_errors.mean(),
            # f'{mode}_ta/med': ta_errors.median(),
            f'{mode}_Trans./Avg. Error': t_errors.mean(),
            f'{mode}_Trans./Med. Error': t_errors.median(),
            # The key names imply translations are expressed in metres.
            f'{mode}_Trans./@10cm ACC': (t_errors < 0.1).float().mean(),
            f'{mode}_Trans./@1m ACC': (t_errors < 1.0).float().mean(),
        }, sync_dist=True)

        self.r_errors[mode].clear()
        self.ta_errors[mode].clear()
        self.t_errors[mode].clear()

    def on_train_epoch_end(self):
        """Log and reset the training-split metrics."""
        self._shared_on_epoch_end('train')

    def on_validation_epoch_end(self):
        """Log and reset the validation-split metrics."""
        self._shared_on_epoch_end('valid')

    def on_test_epoch_end(self):
        """Log and reset the test-split metrics."""
        self._shared_on_epoch_end('test')

    def configure_optimizers(self):
        """Build AdamW over the regressor's parameters with a one-cycle LR schedule.

        Only ``self.module`` is optimized; the extractor and the ``s_r`` /
        ``s_t`` parameters are not passed to the optimizer.
        """
        optimizer = torch.optim.AdamW(self.module.parameters(), lr=self.hparams.lr)
        # steps_per_epoch=1 makes the schedule span exactly `epochs` scheduler
        # steps, i.e. it is meant to be stepped once per epoch.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams.lr,
            steps_per_epoch=1,
            epochs=self.hparams.epochs,
            pct_start=self.hparams.pct_start,
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler
        }