import torch
import numpy as np
import cv2
import random
import wandb
from tqdm.auto import tqdm
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
def args_to_omegaconf(args, base_cfg=None):
    """Build an OmegaConf config from ``base_cfg`` and override leaf values
    with matching non-None attributes of ``args`` (e.g. an argparse Namespace).

    Only one level of nesting is handled: children of top-level DictConfig
    nodes are overridden individually; every other top-level key is
    overridden directly. Returns the resulting config.
    """
    cfg = OmegaConf.create(base_cfg)

    def _apply_override(container, name):
        # getattr with a None default collapses the hasattr/getattr pair;
        # missing attributes and explicit Nones are both skipped.
        candidate = getattr(args, name, None)
        if candidate is not None:
            container[name] = candidate

    for top_key in cfg.keys():
        entry = cfg[top_key]
        if isinstance(entry, DictConfig):
            for child_key in entry.keys():
                _apply_override(entry, child_key)
        else:
            _apply_override(cfg, top_key)
    return cfg
def _tb_sanitize(v):
if v is None:
return "null"
if isinstance(v, (bool, int, float, str, torch.Tensor)):
return v
if isinstance(v, Path):
return str(v)
return str(v)
def _flatten_dict(d, prefix=""):
    """Flatten a nested dict into dot-joined keys with sanitized leaf values.

    A non-dict input collapses to a single entry keyed by ``prefix`` (or
    "cfg" when no prefix was given).
    """
    if not isinstance(d, dict):
        return {prefix or "cfg": _tb_sanitize(d)}
    flat = {}
    for name, value in d.items():
        full_key = f"{prefix}.{name}" if prefix else str(name)
        if isinstance(value, dict):
            flat.update(_flatten_dict(value, full_key))
        else:
            flat[full_key] = _tb_sanitize(value)
    return flat
def convert_paths_to_pathlib(cfg):
    """Recursively convert values stored under keys containing "path"
    (case-insensitive) into pathlib.Path objects; None values stay None.
    Mutates ``cfg`` in place and returns it.
    """
    for name, value in cfg.items():
        if isinstance(value, DictConfig):
            cfg[name] = convert_paths_to_pathlib(value)
            continue
        if "path" in name.lower():
            cfg[name] = None if value is None else Path(value)
    return cfg
def convert_pathlib_to_strings(cfg):
    """Inverse of convert_paths_to_pathlib: recursively turn pathlib.Path
    values back into plain strings so the config can be serialized.
    Mutates ``cfg`` in place and returns it.
    """
    for name, value in cfg.items():
        if isinstance(value, DictConfig):
            cfg[name] = convert_pathlib_to_strings(value)
        elif isinstance(value, Path):
            cfg[name] = str(value)
    return cfg
def prepare_trained_parameters(unet, cfg):
    """Select which UNet parameters to optimize and set their grad flags.

    With ``cfg.training.only_train_attention_layers`` enabled, only
    transformer-block weights are trainable (plus ``conv_in`` when the "uv"
    positional encoding is configured); all other parameters are frozen.
    Otherwise every parameter trains. Returns the list of trainable
    parameters.
    """
    trainable = []
    if not cfg.training.only_train_attention_layers:
        # Full fine-tuning: everything trains.
        for param in unet.parameters():
            param.requires_grad_(True)
            trainable.append(param)
        return trainable

    uv_encoding = cfg.model.unet_positional_encoding == "uv"
    for name, param in unet.named_parameters():
        selected = "transformer_blocks" in name or (uv_encoding and "conv_in" in name)
        param.requires_grad_(selected)
        if selected:
            trainable.append(param)
    return trainable
@torch.no_grad()
def validation_loop(accelerator, dataloader, pager, ema_unet, cfg, epoch, global_step, val_type="val"):
    """Run one evaluation pass over ``dataloader``; log loss and sample images.

    Parameters
    ----------
    accelerator : used to reduce the loss across processes, gate logging to
        the main process, and dispatch metric/image logging (presumably a
        HuggingFace Accelerate object — TODO confirm).
    dataloader : validation loader; batches are dict-like with an "rgb" entry
        and a "depth"/"normal" entry matching ``cfg.model.modality`` — assumed
        from the indexing below, verify against the dataset class.
    pager : model wrapper called to predict a cubemap and to compute/post-
        process modality losses; exposes ``unwrapped_unet``.
    ema_unet : EMA helper; when ``cfg.training.use_EMA`` is set, EMA weights
        are swapped into the UNet for evaluation and restored at the end.
    cfg : experiment config (attribute-style access).
    epoch, global_step : training progress counters used as logging x-axes.
    val_type : "val" (x-axis = epoch) or "tiny_val" (x-axis = global_step).

    Returns
    -------
    Mean reduced loss over the dataloader.

    Raises
    ------
    ValueError
        If ``val_type`` is not "val" or "tiny_val".
    """
    # Pick the progress-bar label and the logging x-axis for this run type.
    if val_type == "val":
        desc = "Validation"
        x_axis_name = "epoch"
        x_axis = epoch
    elif val_type == "tiny_val":
        desc = "Tiny Validation"
        x_axis_name = "global_step"
        x_axis = global_step
    else:
        raise ValueError(f"Unknown val type {val_type}")
    # Evaluate with EMA weights: stash the live weights, copy EMA in.
    if cfg.training.use_EMA:
        ema_unet.store(pager.unwrapped_unet.parameters())
        ema_unet.copy_to(pager.unwrapped_unet.parameters())
    val_epoch_loss = 0.0
    log_val_images = {"rgb": [], cfg.model.modality: []}
    # NOTE(review): random.sample raises ValueError if len(dataloader) < 4 —
    # confirm the validation loader always yields at least 4 batches.
    log_img_ids = random.sample(range(len(dataloader)), 4)
    progress_bar = tqdm(dataloader, desc=desc, total=len(dataloader), disable=not accelerator.is_main_process)
    for i, batch in enumerate(progress_bar):
        pred_cubemap = pager(batch, cfg.model.modality)
        if cfg.model.modality == "depth":
            # Depth normalization constants come from the dataset, in log or
            # linear space depending on cfg.model.log_scale.
            min_depth = dataloader.dataset.LOG_MIN_DEPTH if cfg.model.log_scale else dataloader.dataset.MIN_DEPTH
            depth_range = dataloader.dataset.LOG_DEPTH_RANGE if cfg.model.log_scale else dataloader.dataset.DEPTH_RANGE
            loss = pager.calculate_depth_loss(batch, pred_cubemap, min_depth, depth_range, cfg.model.log_scale, cfg.model.metric_depth)
        elif cfg.model.modality == "normal":
            loss = pager.calculate_normal_loss(batch, pred_cubemap)
        # NOTE(review): ``loss`` is only bound for the "depth"/"normal"
        # modalities; any other modality would raise NameError here.
        avg_loss = accelerator.reduce(loss["total_loss"].detach(), reduction="mean")
        if accelerator.is_main_process:
            progress_bar.set_postfix({"loss": avg_loss.item()})
        val_epoch_loss += avg_loss
        # Keep a few randomly chosen batches for image logging.
        if i in log_img_ids:
            log_val_images["rgb"].append(prepare_image_for_logging(batch["rgb"][0].cpu().numpy()))
            if cfg.model.modality == "depth":
                result_image = pager.process_depth_output(pred_cubemap, orig_size=batch['depth'].shape[2:4], min_depth=min_depth,
                                                          depth_range=depth_range, log_scale=cfg.model.log_scale)[1].cpu().numpy()
            elif cfg.model.modality == "normal":
                result_image = pager.process_normal_output(pred_cubemap, orig_size=batch['normal'].shape[2:4]).cpu().numpy()
            log_val_images[cfg.model.modality].append(prepare_image_for_logging(result_image))
    val_epoch_loss = val_epoch_loss / len(dataloader)
    if accelerator.is_main_process:
        accelerator.log({x_axis_name: x_axis, f"{val_type}/loss": float(val_epoch_loss)}, step=global_step)
        # Build mosaics of the sampled RGB inputs and the predicted outputs,
        # then route them to whichever tracker is configured.
        img_mix_rgb = log_images_mosaic(log_val_images["rgb"])
        img_mix_depth = log_images_mosaic(log_val_images[cfg.model.modality])
        if cfg.logging.report_to == "wandb":
            accelerator.log(
                {x_axis_name: x_axis, f"{val_type}/pred_panorama_rgb": wandb.Image(img_mix_rgb)},
                step=global_step,
            )
            accelerator.log(
                {x_axis_name: x_axis, f"{val_type}/pred_panorama_{cfg.model.modality}": wandb.Image(img_mix_depth)},
                step=global_step,
            )
        elif cfg.logging.report_to == "tensorboard":
            tb_writer = accelerator.get_tracker("tensorboard").writer
            tb_writer.add_image(
                f"{val_type}/pred_panorama_rgb",
                img_mix_rgb,
                global_step,
                dataformats="HWC",
            )
            tb_writer.add_image(
                f"{val_type}/pred_panorama_{cfg.model.modality}",
                img_mix_depth,
                global_step,
                dataformats="HWC",
            )
    # Put the live (non-EMA) weights back before training resumes.
    if cfg.training.use_EMA:
        ema_unet.restore(pager.unwrapped_unet.parameters())
    return val_epoch_loss
def prepare_image_for_logging(image):
    """Min-max normalize an array to [0, 1] and convert to uint8 [0, 255].

    The small epsilon in the denominator keeps constant images from
    dividing by zero (they map to all zeros).
    """
    lo, hi = image.min(), image.max()
    scaled = (image - lo) / (hi - lo + 1e-8)
    return (scaled * 255).astype("uint8")
def log_images_mosaic(images):
    """Arrange 1-4 CHW uint8 images into one HWC uint8 mosaic.

    Each image is expanded to 3 channels if needed, transposed to HWC, and
    resized to 1920x1080. Layouts: 1 image is returned as-is; 2 go side by
    side; 3 get one centered on the top row and two below; 4 fill a 2x2
    grid. Unused cells stay black.
    """
    count = len(images)
    assert 1 <= count <= 4, "Provide between 1 and 4 images (CHW uint8)."
    H, W, C = 1080, 1920, 3
    resized = []
    for chw in images:
        assert chw.dtype == np.uint8 and chw.ndim == 3 and chw.shape[0] in (1, 3), \
            "Each image must be uint8 with shape (C,H,W), C in {1,3}."
        if chw.shape[0] == 1:
            # Grayscale: replicate the single channel to RGB.
            chw = np.repeat(chw, 3, axis=0)
        hwc = np.transpose(chw, (1, 2, 0))
        resized.append(cv2.resize(hwc, (W, H), interpolation=cv2.INTER_LINEAR))

    if count == 1:
        return resized[0]
    if count == 2:
        mosaic = np.zeros((H, 2 * W, C), dtype=np.uint8)
        for col, img in enumerate(resized):
            mosaic[:, col * W:(col + 1) * W, :] = img
        return mosaic

    mosaic = np.zeros((2 * H, 2 * W, C), dtype=np.uint8)
    if count == 3:
        # First image centered on the top row, remaining two underneath.
        shift = W // 2
        mosaic[0:H, shift:shift + W, :] = resized[0]
        mosaic[H:2 * H, 0:W, :] = resized[1]
        mosaic[H:2 * H, W:2 * W, :] = resized[2]
    else:
        # Four images: row-major 2x2 grid.
        for idx, img in enumerate(resized):
            row, col = divmod(idx, 2)
            mosaic[row * H:(row + 1) * H, col * W:(col + 1) * W, :] = img
    return mosaic