File size: 10,954 Bytes
436b829 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 | import os
from os.path import join
import numpy as np
import imageio
import torch
import cv2
import pytorch_lightning as pl
from hydra.utils import instantiate
from typing import Any, Dict, List
from ppd.utils.align_depth_func import recover_metric_depth_ransac
from ppd.utils.parallel_utils import async_call
from ppd.utils.logger import Log
from ppd.utils.vis_utils import visualize_depth
class DepthEstimationModel(pl.LightningModule):
    """Lightning wrapper that trains, validates and runs inference for a depth pipeline.

    The wrapped ``pipeline`` is expected to expose ``forward_train(batch)``
    (returning a dict with at least ``'loss'``) and ``forward_test(batch)``
    (returning a dict with ``'depth'`` and ``'image'``).
    """

    def __init__(
        self,
        pipeline,  # Hydra config of the model itself
        optimizer,  # Hydra config of the optimizer used to train the model
        lr_table,  # Hydra config of the learning-rate table
        output_dir: str,
        ignored_weights_prefix=None,
        save_vis_depth=False,  # Whether to save the visualized depth
        save_vis_depth_and_concat_img=False,
        save_vis_depth_and_concat_gt=True,
        **kwargs,
    ):
        """
        Args:
            pipeline: Hydra config for the model, instantiated with ``_recursive_=False``.
            optimizer: Hydra config; its instantiation is later called with
                ``params=...`` in ``configure_optimizers``.
            lr_table: Hydra config for an object providing ``get_lr(name) -> (group, lr)``.
            output_dir: root directory for saved depth maps and visualizations.
            ignored_weights_prefix: state-dict key prefixes removed when saving a
                checkpoint and not reported as missing when loading one. Defaults to
                ``["pipeline.sem_encoder"]``; a single string is treated as one prefix.
            save_vis_depth: save depth visualizations in ``predict_step``.
            save_vis_depth_and_concat_img: place the RGB input left of each visualization.
            save_vis_depth_and_concat_gt: place the GT depth visualization right of it
                (only when GT depth is available).
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.pipeline = instantiate(pipeline, _recursive_=False)
        self.optimizer = instantiate(optimizer)
        self.lr_table = instantiate(lr_table)
        # Build a fresh list per instance instead of sharing a mutable default.
        if ignored_weights_prefix is None:
            ignored_weights_prefix = ["pipeline.sem_encoder"]
        elif isinstance(ignored_weights_prefix, str):
            ignored_weights_prefix = [ignored_weights_prefix]
        self.ignored_weights_prefix = list(ignored_weights_prefix)
        self._save_vis_depth = save_vis_depth
        self._save_vis_depth_and_concat_img = save_vis_depth_and_concat_img
        self._save_vis_depth_and_concat_gt = save_vis_depth_and_concat_gt
        self.align_depth_func = recover_metric_depth_ransac
        self.output_dir = output_dir
        Log.info('Results will be saved to: {}'.format(self.output_dir))

    @staticmethod
    def _to_uint8(img):
        """Convert an image assumed to lie in [0, 1] to uint8, clipping out-of-range values."""
        return (np.clip(img, 0., 1.) * 255.).astype(np.uint8)

    def training_step(self, batch, batch_idx):
        """Run one training step: forward pass, loss sanity check and logging.

        Returns:
            The pipeline's training output dict, with ``'depth'`` removed if present.

        Raises:
            ValueError: if the loss contains NaN or Inf.
        """
        output = self.pipeline.forward_train(batch)
        train_loader = self.trainer.train_dataloader
        if isinstance(train_loader, (list, tuple)):
            # Several loaders feed each step: log with their combined batch size.
            B = sum(loader.batch_size for loader in train_loader)
        else:
            B = train_loader.batch_size
        loss = output['loss']
        if torch.isnan(loss).any() or torch.isinf(loss).any():
            raise ValueError(f"Loss is NaN or Inf: {loss}")
        self.log('train/loss', loss, on_step=True, on_epoch=True,
                 prog_bar=True, logger=True, batch_size=B, sync_dist=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('train/lr', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # Save visualization every 100 steps (only when a logger is attached).
        if (self.global_step % 100 == 0 and self.logger is not None
                and 'depth' in output and 'image' in output):
            self._log_train_depth_vis(output)
        # Remove the depth prediction from the returned dict.
        if 'depth' in output:
            del output['depth']
        return output

    def _log_train_depth_vis(self, output):
        """Send an ``RGB | depth`` image of the first sample to ``logger.experiment.add_image``.

        NOTE(review): assumes the RGB image is CHW with values in [0, 1] and that
        ``visualize_depth`` returns an HxWx3 float image in [0, 1] — confirm.
        """
        depth_np = output['depth'][0][0].float().detach().cpu().numpy()
        # .float() first: numpy has no bfloat16, so .numpy() would fail under bf16 training.
        rgb_np = output['image'][0].float().detach().cpu().numpy().transpose((1, 2, 0))
        depth_vis = visualize_depth(depth_np)
        if rgb_np.shape[:2] != depth_vis.shape[:2]:
            # Horizontal concatenation needs equal heights.
            rgb_np = cv2.resize(rgb_np, (depth_vis.shape[1], depth_vis.shape[0]),
                                interpolation=cv2.INTER_AREA)
        vis_img = np.concatenate([self._to_uint8(rgb_np), self._to_uint8(depth_vis)], axis=1)
        self.logger.experiment.add_image('train/depth_vis',
                                         vis_img.transpose((2, 0, 1)),
                                         self.global_step)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Run ``pipeline.forward_test`` and optionally save depth visualizations."""
        output = self.pipeline.forward_test(batch)
        if self._save_vis_depth:
            self.save_vis_depth(output['depth'], output['image'], batch['image_name'], 'vis_depth',
                                gt_depth=batch['depth'] if 'depth' in batch else None)
        return output

    def validation_step(self, batch, batch_idx, dataloader_idx=None) -> None:
        """Predict, compute per-sample metrics, and log each metric's batch mean under ``val/``."""
        output = self.predict_step(batch, batch_idx, dataloader_idx)
        batch_size = batch['image'].shape[0]
        metrics_dict = self.compute_metrics(output, batch)
        for k, v in metrics_dict.items():
            self.log(f'val/{k}', np.mean(v),
                     on_step=False,
                     on_epoch=True,
                     prog_bar='l1' in k,
                     logger=True,
                     batch_size=batch_size,
                     sync_dist=True)

    def compute_metrics(self, output, batch):
        """Compute depth metrics for every sample after aligning prediction to GT.

        Returns:
            Dict mapping ``'relative_<metric>'`` to a list with one value per sample.
        """
        B = batch['image'].shape[0]
        metrics_dict = {}
        for b in range(B):
            pred_depth = output['depth'][b][0].float().detach().cpu().numpy()
            gt_depth = batch['depth'][b][0].float().detach().cpu().numpy()
            # NOTE(review): the whole batch['dataset_name'] (not the b-th entry) is
            # passed; the base create_depth_mask ignores it — confirm for subclasses.
            msk = self.create_depth_mask(batch['dataset_name'], gt_depth)
            msk = msk & batch['mask'][b, 0].detach().cpu().numpy().astype(np.bool_)
            gt_depth[~msk] = 0.
            pred_depth = self.align_depth_func(
                pred_depth, gt_depth, msk, log=True)
            # NOTE(review): an all-False mask yields NaN metrics (mean of an empty array).
            metrics_dict_item = self.compute_depth_metric(
                pred_depth, gt_depth, msk)
            metrics_dict = self.update_metrics_dict(
                metrics_dict, metrics_dict_item, 'relative')
        return metrics_dict

    def update_metrics_dict(self, metrics_dict, metrics_dict_item, prefix):
        """Append each ``metrics_dict_item`` value to ``metrics_dict['<prefix>_<key>']`` in place."""
        for k, v in metrics_dict_item.items():
            metrics_dict.setdefault(f'{prefix}_{k}', []).append(v)
        return metrics_dict

    def create_depth_mask(self, dataset_name, gt_depth):
        """Return the boolean mask of GT pixels with depth greater than 1e-3."""
        return gt_depth > 1e-3

    def compute_depth_metric(self, pred_depth, gt_depth, msk):
        """Compute threshold accuracies and absolute relative error over masked pixels.

        ``dX`` is the fraction of pixels whose ratio max(gt/pred, pred/gt) is below
        ``1.25 ** X``; the 1e-5 terms guard against division by zero.
        """
        gt = gt_depth[msk]
        pred = pred_depth[msk]
        thresh = np.maximum((gt / (pred + 1e-5)), (pred / (gt + 1e-5)))
        d05 = (thresh < 1.25 ** 0.5).mean()
        d1 = (thresh < 1.25).mean()
        d2 = (thresh < 1.25 ** 2).mean()
        d3 = (thresh < 1.25 ** 3).mean()
        abs_rel = np.mean(np.abs(gt - pred) / (gt + 1e-5))
        return {
            'd0.5': d05,
            'd1': d1,
            'd2': d2,
            'd3': d3,
            'abs_rel': abs_rel,
        }

    @async_call
    def save_depth(self, depth, name, tag) -> None:
        """Write each depth map to ``<output_dir>/<tag>/<name without extension>.npz``.

        Values are rounded to 3 decimals and stored under the ``data`` key.
        ``name`` may be a sequence of names (one per sample) or a single string.
        """
        if not isinstance(depth, torch.Tensor):
            # Wrap a single non-tensor map so that depth[0][0] is that map.
            depth = torch.tensor(depth).unsqueeze(0).unsqueeze(0)
        if isinstance(name, str):
            # Avoid indexing the characters of a single name.
            name = [name]
        for b in range(len(depth)):
            depth_np = depth[b][0].float().detach().cpu().numpy()
            # splitext keeps names without an extension intact (old slicing emptied them).
            save_name = os.path.splitext(name[b])[0] + '.npz'
            img_path = join(self.output_dir, f'{tag}/{save_name}')
            os.makedirs(os.path.dirname(img_path), exist_ok=True)
            np.savez_compressed(img_path, data=np.round(depth_np, 3))

    @staticmethod
    def _vis_output_path(path):
        """Return ``path`` with a ``.jpg``/``.jpeg``/missing extension replaced by ``.png``."""
        root, ext = os.path.splitext(path)
        if ext.lower() in ('.jpg', '.jpeg', ''):
            return root + '.png'
        return path

    @async_call
    def save_vis_depth(self, depth, rgb, name, tag, gt_depth=None) -> None:
        """Save a per-sample visualization ``[RGB |] prediction [| GT]`` under ``<output_dir>/<tag>/``.

        The prediction (and GT) are colorized with their own min/max range. RGB is
        included when ``save_vis_depth_and_concat_img`` is set; GT when it is given and
        ``save_vis_depth_and_concat_gt`` is set.
        """
        if isinstance(name, str):
            name = [name]
        for b in range(len(depth)):
            depth_np = depth[b][0].float().detach().cpu().numpy()
            depth_vis = visualize_depth(depth_np,
                                        depth_np.min(),
                                        depth_np.max())
            save_img = depth_vis
            if self._save_vis_depth_and_concat_img:
                rgb_np = rgb[b].float().detach().cpu().numpy().transpose((1, 2, 0))
                rgb_np = cv2.resize(
                    rgb_np, (depth_vis.shape[1], depth_vis.shape[0]), interpolation=cv2.INTER_AREA)
                save_img = np.concatenate([rgb_np, save_img], axis=1)
            if gt_depth is not None and self._save_vis_depth_and_concat_gt:
                gt_depth_np = gt_depth[b][0].float().detach().cpu().numpy()
                gt_depth_vis = visualize_depth(gt_depth_np,
                                               gt_depth_np.min(),
                                               gt_depth_np.max())
                if gt_depth_vis.shape[:2] != depth_vis.shape[:2]:
                    # Horizontal concatenation needs equal heights.
                    gt_depth_vis = cv2.resize(
                        gt_depth_vis, (depth_vis.shape[1], depth_vis.shape[0]),
                        interpolation=cv2.INTER_NEAREST)
                save_img = np.concatenate([save_img, gt_depth_vis], axis=1)
            img_path = join(self.output_dir, f'{tag}/{name[b]}')
            os.makedirs(os.path.dirname(img_path), exist_ok=True)
            # Swap only the trailing extension (a plain str.replace also hit directory names).
            imageio.imwrite(self._vis_output_path(img_path), self._to_uint8(save_img))

    def configure_optimizers(self):
        """Build the optimizer with one param group per ``lr_table`` group.

        Each trainable parameter is assigned ``(group, lr)`` by ``lr_table.get_lr(name)``.
        Parameters with ``lr == 0`` get ``requires_grad = False``.
        """
        group_table = {}
        params = []
        for k, v in self.pipeline.named_parameters():
            if not v.requires_grad:
                continue
            group, lr = self.lr_table.get_lr(k)
            if lr == 0:
                v.requires_grad = False
            # NOTE: frozen parameters are still added to their group; removing them
            # would change the optimizer's param-group layout (and its saved state).
            if group not in group_table:
                group_table[group] = len(group_table)
                params.append({'params': [v], 'lr': lr, 'name': group})
            else:
                params[group_table[group]]['params'].append(v)
        optimizer = self.optimizer(params=params)
        return optimizer

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Remove every ``state_dict`` entry whose key starts with an ignored prefix."""
        state_dict = checkpoint["state_dict"]
        for ig_keys in self.ignored_weights_prefix:
            Log.debug(f"Remove key `{ig_keys}' from checkpoint.")
            for k in [k for k in state_dict if k.startswith(ig_keys)]:
                state_dict.pop(k)
        super().on_save_checkpoint(checkpoint)

    def _report_load_result(self, missing, unexpected):
        """Log missing keys (excluding ignored prefixes) as warnings and unexpected keys as errors."""
        prefixes = tuple(self.ignored_weights_prefix)
        real_missing = [k for k in missing if not k.startswith(prefixes)]
        if real_missing:
            Log.warn(f"Missing keys: {real_missing}")
        if unexpected:
            Log.error(f"Unexpected keys: {unexpected}")

    def load_pretrained_model(self, ckpt_path):
        """Load a checkpoint's ``state_dict`` entry into this module (non-strict).

        Missing keys under ``ignored_weights_prefix`` are expected and not reported.
        """
        Log.info(f"Loading ckpt: {ckpt_path}")
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, "cpu")["state_dict"]
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        self._report_load_result(missing, unexpected)

    def load_pretrained_model_eval(self, ckpt_path):
        """Load a flat state dict into this module (non-strict).

        Keys starting with ``dit.`` are re-rooted under ``pipeline.``; other keys are
        used as-is. Missing keys under ``ignored_weights_prefix`` are not reported.
        """
        Log.info(f"Loading ckpt: {ckpt_path}")
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, "cpu")
        fixed_state_dict = {
            (f"pipeline.{k}" if k.startswith("dit.") else k): v
            for k, v in state_dict.items()
        }
        missing, unexpected = self.load_state_dict(fixed_state_dict, strict=False)
        self._report_load_result(missing, unexpected)
|