| | import torch |
| | import numpy as np |
| | import cv2 |
| | import random |
| | import wandb |
| | from tqdm.auto import tqdm |
| | from omegaconf import OmegaConf, DictConfig |
| | from pathlib import Path |
| |
|
def args_to_omegaconf(args, base_cfg=None):
    """Build an OmegaConf config from ``base_cfg`` and overlay values from ``args``.

    For every top-level key of the config:
      * if the entry is a ``DictConfig`` section, each of its direct keys is
        overridden by the attribute of ``args`` with the same name;
      * otherwise the top-level key itself is overridden the same way.

    An override only happens when ``args`` has the attribute and its value is
    not ``None``. Only one level of nesting is inspected.

    Args:
        args: Object whose attributes supply override values (e.g. an
            ``argparse.Namespace`` — presumably; any attribute holder works).
        base_cfg: Anything accepted by ``OmegaConf.create``.

    Returns:
        The resulting config object.
    """
    cfg = OmegaConf.create(base_cfg)

    def _apply(target, name):
        # A missing attribute and an explicit None are both "not provided".
        override = getattr(args, name, None)
        if override is not None:
            target[name] = override

    for top_key in cfg.keys():
        section = cfg[top_key]
        if not isinstance(section, DictConfig):
            _apply(cfg, top_key)
            continue
        for field in section.keys():
            _apply(section, field)

    return cfg
| |
|
def _tb_sanitize(v):
    """Coerce a config value into something loggable as a hyper-parameter.

    ``None`` becomes the string ``"null"``; bools, ints, floats, strings and
    torch tensors are returned unchanged; every other value (including
    ``pathlib.Path``) is converted with ``str``.
    """
    if v is None:
        return "null"
    passthrough_types = (bool, int, float, str, torch.Tensor)
    return v if isinstance(v, passthrough_types) else str(v)
| |
|
def _flatten_dict(d, prefix=""):
    """Flatten a nested ``dict`` into a single-level dict with dotted keys.

    Leaf values are passed through ``_tb_sanitize``. Nested keys are joined
    with ``"."``. When ``d`` itself is not a dict, the result is a one-entry
    dict keyed by ``prefix`` (or ``"cfg"`` when ``prefix`` is empty).
    """
    if not isinstance(d, dict):
        return {prefix or "cfg": _tb_sanitize(d)}

    flat = {}
    for name, item in d.items():
        dotted = f"{prefix}.{name}" if prefix else str(name)
        if isinstance(item, dict):
            # Recurse so grandchildren get the full dotted path.
            flat.update(_flatten_dict(item, dotted))
        else:
            flat[dotted] = _tb_sanitize(item)
    return flat
| |
|
def convert_paths_to_pathlib(cfg):
    """Convert path-like entries of ``cfg`` to ``pathlib.Path``, in place.

    ``DictConfig`` children are processed recursively. Any other entry whose
    key contains ``"path"`` (case-insensitive) is replaced by ``Path(value)``;
    ``None`` values are kept as ``None``.

    Returns:
        The same ``cfg`` object that was passed in.
    """
    for name, entry in cfg.items():
        if isinstance(entry, DictConfig):
            cfg[name] = convert_paths_to_pathlib(entry)
            continue
        if 'path' not in name.lower():
            continue
        cfg[name] = None if entry is None else Path(entry)
    return cfg
| |
|
| |
|
def convert_pathlib_to_strings(cfg):
    """Replace every ``pathlib.Path`` value in ``cfg`` with its string form, in place.

    ``DictConfig`` children are processed recursively; all other values are
    left untouched.

    Returns:
        The same ``cfg`` object that was passed in.
    """
    for name, entry in cfg.items():
        if isinstance(entry, DictConfig):
            cfg[name] = convert_pathlib_to_strings(entry)
            continue
        if isinstance(entry, Path):
            cfg[name] = str(entry)
    return cfg
| |
|
| |
|
def prepare_trained_parameters(unet, cfg):
    """Set ``requires_grad`` on the parameters of ``unet`` and return the trainable ones.

    When ``cfg.training.only_train_attention_layers`` is truthy, only
    parameters whose name contains ``"transformer_blocks"`` are trained, plus
    those whose name contains ``"conv_in"`` if
    ``cfg.model.unet_positional_encoding == "uv"``; every other parameter is
    frozen. Otherwise all parameters are trained.

    Returns:
        List of the parameters that were marked trainable, in iteration order.
    """
    if not cfg.training.only_train_attention_layers:
        # Full fine-tuning: every parameter is trainable.
        trainable = list(unet.parameters())
        for weight in trainable:
            weight.requires_grad_(True)
        return trainable

    trainable = []
    for name, weight in unet.named_parameters():
        wanted = (
            cfg.model.unet_positional_encoding == "uv" and "conv_in" in name
        ) or "transformer_blocks" in name
        if wanted:
            trainable.append(weight)
        weight.requires_grad_(bool(wanted))
    return trainable
| |
|
| |
|
@torch.no_grad()
def validation_loop(accelerator, dataloader, pager, ema_unet, cfg, epoch, global_step, val_type="val"):
    """Run one validation pass and log the mean loss plus a mosaic of sample images.

    The loss of every batch is averaged across processes with
    ``accelerator.reduce`` and accumulated; the mean over all batches is logged
    under ``f"{val_type}/loss"``. Up to four randomly chosen batches contribute
    their first RGB image and the corresponding prediction to image mosaics,
    which the main process logs to wandb or tensorboard depending on
    ``cfg.logging.report_to``.

    If ``cfg.training.use_EMA`` is set, the EMA weights are swapped into
    ``pager.unwrapped_unet`` for the duration of the pass and the original
    weights are restored afterwards, even if validation raises.

    Args:
        accelerator: Object providing ``is_main_process``, ``reduce``, ``log``
            and ``get_tracker`` (presumably an ``accelerate.Accelerator`` —
            confirm against the caller).
        dataloader: Sized iterable of batches; ``dataloader.dataset`` must
            expose the depth constants when the modality is ``"depth"``.
        pager: Model wrapper; called as ``pager(batch, modality)`` and used for
            loss computation and output post-processing.
        ema_unet: EMA helper exposing ``store``, ``copy_to`` and ``restore``.
        cfg: Config with ``model``, ``training`` and ``logging`` sections.
        epoch: Current epoch; x-axis value for ``val_type="val"``.
        global_step: Current optimisation step; logging step and x-axis value
            for ``val_type="tiny_val"``.
        val_type: ``"val"`` or ``"tiny_val"``.

    Returns:
        The mean (process-averaged) validation loss over all batches.

    Raises:
        ValueError: If ``val_type`` or ``cfg.model.modality`` is not supported,
            or if ``dataloader`` yields no batches.
    """
    if val_type == "val":
        desc, x_axis_name, x_axis = "Validation", "epoch", epoch
    elif val_type == "tiny_val":
        desc, x_axis_name, x_axis = "Tiny Validation", "global_step", global_step
    else:
        raise ValueError(f"Unknown val type {val_type}")

    # Fail early: an unsupported modality would otherwise surface as an
    # unbound `loss` variable in the middle of the loop.
    modality = cfg.model.modality
    if modality not in ("depth", "normal"):
        raise ValueError(f"Unknown modality {modality}")

    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError(f"{desc} dataloader is empty")

    # Depth normalisation constants do not change between batches.
    if modality == "depth":
        log_scale = cfg.model.log_scale
        dataset = dataloader.dataset
        min_depth = dataset.LOG_MIN_DEPTH if log_scale else dataset.MIN_DEPTH
        depth_range = dataset.LOG_DEPTH_RANGE if log_scale else dataset.DEPTH_RANGE

    use_ema = cfg.training.use_EMA
    if use_ema:
        ema_unet.store(pager.unwrapped_unet.parameters())

    try:
        if use_ema:
            ema_unet.copy_to(pager.unwrapped_unet.parameters())

        val_epoch_loss = 0.0
        log_val_images = {"rgb": [], modality: []}
        # Never request more samples than there are batches; random.sample
        # raises ValueError when the sample exceeds the population.
        log_img_ids = set(random.sample(range(num_batches), min(4, num_batches)))

        progress_bar = tqdm(dataloader, desc=desc, total=num_batches, disable=not accelerator.is_main_process)
        for i, batch in enumerate(progress_bar):
            pred_cubemap = pager(batch, modality)
            if modality == "depth":
                loss = pager.calculate_depth_loss(
                    batch, pred_cubemap, min_depth, depth_range, log_scale, cfg.model.metric_depth
                )
            else:
                loss = pager.calculate_normal_loss(batch, pred_cubemap)

            avg_loss = accelerator.reduce(loss["total_loss"].detach(), reduction="mean")
            if accelerator.is_main_process:
                progress_bar.set_postfix({"loss": avg_loss.item()})
            val_epoch_loss += avg_loss

            if i in log_img_ids:
                log_val_images["rgb"].append(prepare_image_for_logging(batch["rgb"][0].cpu().numpy()))
                if modality == "depth":
                    result_image = pager.process_depth_output(
                        pred_cubemap,
                        orig_size=batch['depth'].shape[2:4],
                        min_depth=min_depth,
                        depth_range=depth_range,
                        log_scale=log_scale,
                    )[1].cpu().numpy()
                else:
                    result_image = pager.process_normal_output(
                        pred_cubemap, orig_size=batch['normal'].shape[2:4]
                    ).cpu().numpy()
                log_val_images[modality].append(prepare_image_for_logging(result_image))

        val_epoch_loss = val_epoch_loss / num_batches

        if accelerator.is_main_process:
            accelerator.log({x_axis_name: x_axis, f"{val_type}/loss": float(val_epoch_loss)}, step=global_step)

            img_mix_rgb = log_images_mosaic(log_val_images["rgb"])
            img_mix_pred = log_images_mosaic(log_val_images[modality])

            if cfg.logging.report_to == "wandb":
                accelerator.log(
                    {x_axis_name: x_axis, f"{val_type}/pred_panorama_rgb": wandb.Image(img_mix_rgb)},
                    step=global_step,
                )
                accelerator.log(
                    {x_axis_name: x_axis, f"{val_type}/pred_panorama_{modality}": wandb.Image(img_mix_pred)},
                    step=global_step,
                )
            elif cfg.logging.report_to == "tensorboard":
                tb_writer = accelerator.get_tracker("tensorboard").writer
                tb_writer.add_image(
                    f"{val_type}/pred_panorama_rgb",
                    img_mix_rgb,
                    global_step,
                    dataformats="HWC",
                )
                tb_writer.add_image(
                    f"{val_type}/pred_panorama_{modality}",
                    img_mix_pred,
                    global_step,
                    dataformats="HWC",
                )
    finally:
        # Always put the training weights back, even when validation fails,
        # so a crash here cannot leave the model running on EMA weights.
        if use_ema:
            ema_unet.restore(pager.unwrapped_unet.parameters())

    return val_epoch_loss
| |
|
| |
|
def prepare_image_for_logging(image):
    """Min-max normalise ``image`` and convert it to ``uint8``.

    The array is shifted by its minimum and divided by its value range (plus
    ``1e-8`` to avoid division by zero on constant images), scaled by 255 and
    cast to ``uint8`` (the cast truncates rather than rounds).

    Returns:
        A ``uint8`` array with the same shape as ``image``.
    """
    lowest = image.min()
    value_range = image.max() - lowest + 1e-8
    normalised = (image - lowest) / value_range
    return (normalised * 255).astype("uint8")
| |
|
| |
|
def log_images_mosaic(images):
    """Combine one to four CHW ``uint8`` images into a single HWC mosaic.

    Each image is expanded to three channels if it has one, transposed to HWC
    and resized to 1920x1080. The tiles are then arranged as follows:

      * 1 image  -> the resized tile itself (1080 x 1920);
      * 2 images -> side by side (1080 x 3840);
      * 3 images -> first tile centred on the top row, the other two on the
        bottom row (2160 x 3840, unused area black);
      * 4 images -> 2 x 2 grid (2160 x 3840).

    Args:
        images: Sequence of 1-4 ``uint8`` arrays shaped ``(C, H, W)`` with
            ``C`` in ``{1, 3}``.

    Returns:
        A ``uint8`` array in HWC layout.
    """
    count = len(images)
    assert 1 <= count <= 4, "Provide between 1 and 4 images (CHW uint8)."

    tile_h, tile_w, channels = 1080, 1920, 3

    tiles = []
    for chw in images:
        assert chw.dtype == np.uint8 and chw.ndim == 3 and chw.shape[0] in (1, 3), \
            "Each image must be uint8 with shape (C,H,W), C in {1,3}."

        # Grayscale inputs are replicated across three channels.
        three_channel = np.repeat(chw, 3, axis=0) if chw.shape[0] == 1 else chw
        hwc = np.transpose(three_channel, (1, 2, 0))
        tiles.append(cv2.resize(hwc, (tile_w, tile_h), interpolation=cv2.INTER_LINEAR))

    if count == 1:
        return tiles[0]

    # (rows, cols) of the canvas in tiles, and the top-left corner of each tile.
    if count == 2:
        grid = (1, 2)
        corners = [(0, 0), (0, tile_w)]
    elif count == 3:
        grid = (2, 2)
        corners = [(0, tile_w // 2), (tile_h, 0), (tile_h, tile_w)]
    else:
        grid = (2, 2)
        corners = [(0, 0), (0, tile_w), (tile_h, 0), (tile_h, tile_w)]

    canvas = np.zeros((grid[0] * tile_h, grid[1] * tile_w, channels), dtype=np.uint8)
    for tile, (top, left) in zip(tiles, corners):
        canvas[top:top + tile_h, left:left + tile_w, :] = tile
    return canvas
| |
|
| |
|
| |
|