import torch
from torch import nn
from torch.nn import Conv2d
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import DDPMScheduler
from diffusers.utils.import_utils import is_xformers_available
from Marigold.unet.unet_2d_condition import UNet2DConditionModel
from Marigold.vae.autoencoder_kl import AutoencoderKL
from src.utils.conv_padding import PaddedConv2d, valid_pad_conv_fn
from src.utils.loss import L1Loss, GradL1Loss, CosineNormalLoss
from src.utils.geometry_utils import (
    get_positional_encoding,
    compute_scale_and_shift,
    compute_shift,
    depth_to_normals_erp,
    cubemap_to_erp,
)


class Pager(nn.Module):
    def __init__(self,
                 model_configs,
                 pretrained_path,
                 train_modality=None,
                 device=torch.device("cpu"),
                 weight_dtype=torch.float32):
        super().__init__()
        self.model_configs = model_configs
        self.weight_dtype = weight_dtype
        self.rgb_latent_scale_factor = 0.18215
        self.depth_latent_scale_factor = 0.18215
        self.train_modality = train_modality
        self.device = device
        self.prepare_model_components(pretrained_path, model_configs)
        self.prepare_empty_encoding()

        # Cache the cumulative schedule products (alpha_bar_t and 1 - alpha_bar_t)
        # and the index of the last timestep, then free the scheduler itself.
        self.alpha_prod = self.noise_scheduler.alphas_cumprod.to(device, dtype=weight_dtype)
        self.beta_prod = 1 - self.alpha_prod
        self.num_timesteps = self.noise_scheduler.config.num_train_timesteps - 1
        del self.noise_scheduler

    def prepare_model_components(self, pretrained_path, model_configs):
        vae_use_RoPE = None
        for checkpoint_cfg in model_configs.values():
            if vae_use_RoPE is None:
                vae_use_RoPE = checkpoint_cfg["config"].vae_use_RoPE == "RoPE"
            elif vae_use_RoPE != (checkpoint_cfg["config"].vae_use_RoPE == "RoPE"):
                raise ValueError("All UNet checkpoints must use the same VAE positional encoding configuration.")

        self.noise_scheduler = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler", rescale_betas_zero_snr=True)
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_path, subfolder="tokenizer", revision=None)
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_path, subfolder="text_encoder", revision=None, variant=None)
        self.vae = AutoencoderKL.from_pretrained(pretrained_path, subfolder="vae", revision=None, variant=None,
                                                 use_RoPE=vae_use_RoPE)
        self.set_valid_pad_conv(self.vae)

        self.vae.requires_grad_(False)
        self.vae.to(self.device, dtype=self.weight_dtype)
        self.vae.eval()

        self.text_encoder.requires_grad_(False)
        self.text_encoder.to(self.device, dtype=self.weight_dtype)
        self.text_encoder.eval()

        # The UNet input is the concatenation of the RGB and target latents
        # (4 + 4 channels), optionally extended with uv positional-encoding channels.
        base_in_channels = 8

        self.unet = {}
        for modality, checkpoint_cfg in model_configs.items():
            # Determine the PE channel count per modality, so that one "uv"
            # config does not leak into the channel count of the others.
            pe_channels_size = 2 if checkpoint_cfg["config"].unet_positional_encoding == "uv" else 0
            target_in_channels = base_in_channels + pe_channels_size

            self.unet[modality] = UNet2DConditionModel.from_pretrained(
                checkpoint_cfg["path"],
                subfolder="unet",
                revision=None,
                in_channels=target_in_channels if checkpoint_cfg["mode"] == "trained" else base_in_channels,
                use_RoPE=checkpoint_cfg["config"].unet_positional_encoding == "RoPE",
            )

            if target_in_channels > base_in_channels and checkpoint_cfg["mode"] != "trained":
                self.extend_unet_conv_in(self.unet[modality], new_in_channels=target_in_channels)
            self.set_valid_pad_conv(self.unet[modality])

        if is_xformers_available() and self.device.type == "cuda":
            if self.unet.get("depth"):
                self.unet["depth"].enable_xformers_memory_efficient_attention()
            if self.unet.get("normal"):
                self.unet["normal"].enable_xformers_memory_efficient_attention()
            self.vae.enable_xformers_memory_efficient_attention()
        else:
            print("xFormers is not available. Proceeding without it.")

    def prepare_training(self, accelerator, gradient_checkpointing):
        self.unwrapped_unet = self.unet[self.train_modality]
        self.unet[self.train_modality] = accelerator.prepare(self.unet[self.train_modality])
        self.trained_unet = self.unet[self.train_modality]

        if gradient_checkpointing:
            # Use the public diffusers API, and call it on the unwrapped model
            # since accelerator.prepare may wrap it (e.g. in DDP).
            self.unwrapped_unet.enable_gradient_checkpointing()
            self.vae.enable_gradient_checkpointing()

    def prepare_cubemap_PE(self, image_height, image_width):
        use_uv_PE = any(
            checkpoint_cfg["config"].unet_positional_encoding == "uv"
            for checkpoint_cfg in self.model_configs.values()
        )
        if use_uv_PE:
            PE_cubemap = get_positional_encoding(image_height, image_width)
            self.PE_cubemap = PE_cubemap.to(device=self.device, dtype=self.weight_dtype)

    def prepare_empty_encoding(self):
        # Pre-compute the CLIP embedding of the empty prompt once; the model is
        # always conditioned on it, so the text stack can be freed afterwards.
        with torch.inference_mode():
            empty_token = self.tokenizer([""], padding="max_length", truncation=True, return_tensors="pt").input_ids
            empty_token = empty_token.to(self.device)
            empty_encoding = self.text_encoder(empty_token, return_dict=False)[0]
            self.empty_encoding = empty_encoding.to(self.device, dtype=self.weight_dtype)

        del empty_token
        del self.text_encoder
        del self.tokenizer

    def forward(self, batch, modality):
        # Encode the cubemap faces (folded into the batch dimension) into
        # scaled VAE latents; no gradients flow through the frozen encoder.
        with torch.inference_mode():
            c, h, w = batch["rgb_cubemap"].shape[2:]
            rgb_vae_input = batch["rgb_cubemap"].reshape(-1, c, h, w).to(dtype=self.weight_dtype)
            rgb_latents = self.vae.encode(rgb_vae_input, deterministic=True)
            rgb_latents = rgb_latents * self.rgb_latent_scale_factor
            del rgb_vae_input

        # Single-step prediction: the "noisy" latent is all zeros and the
        # timestep is fixed to the last step of the schedule.
        timesteps = torch.full((rgb_latents.shape[0],), self.num_timesteps, device=self.device, dtype=torch.long)
        alpha_prod_t = self.alpha_prod[timesteps].view(-1, 1, 1, 1)
        beta_prod_t = self.beta_prod[timesteps].view(-1, 1, 1, 1)

        noisy_latents = torch.zeros_like(rgb_latents)
        encoder_hidden_states = self.empty_encoding.repeat(batch["rgb_cubemap"].shape[0] * 6, 1, 1)
        if self.model_configs[modality]["config"].unet_positional_encoding == "uv":
            batch_PE_cubemap = self.PE_cubemap.repeat(batch["rgb_cubemap"].shape[0], 1, 1, 1)
            unet_input = torch.cat((rgb_latents, noisy_latents, batch_PE_cubemap), dim=1).to(self.device)
        else:
            unet_input = torch.cat((rgb_latents, noisy_latents), dim=1).to(self.device)

        del rgb_latents
        model_pred = self.unet[modality](
            unet_input,
            timesteps,
            encoder_hidden_states,
            return_dict=False,
        )[0]

        # Recover the clean latent from the v-prediction:
        # x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v.
        current_latent_estimate = (alpha_prod_t**0.5) * noisy_latents - (beta_prod_t**0.5) * model_pred
        current_scaled_latent_estimate = current_latent_estimate / self.depth_latent_scale_factor
        pred_cubemap = self.vae.decode(current_scaled_latent_estimate, deterministic=True)

        # Depth is decoded as three near-identical channels; average to one.
        if modality == "depth":
            pred_cubemap = pred_cubemap.mean(dim=1, keepdim=True)
        return pred_cubemap
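    # Usage sketch (the shapes are an assumption, not verified against the
    # data pipeline): with batch["rgb_cubemap"] of shape (B, 6, 3, H, W) in
    # [-1, 1], a call like
    #     pred = pager(batch, modality="depth")
    # folds the six faces into the batch dimension and returns a
    # (B * 6, 1, h, w) cubemap prediction.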
    def prepare_losses_dict(self, loss_cfg):
        self.losses_dict = {}
        if self.train_modality == "depth":
            self.losses_dict["l1_loss"] = {"loss_fn": L1Loss(invalid_mask_weight=loss_cfg.invalid_mask_weight),
                                           "weight": loss_cfg.l1_loss_weight}
            if loss_cfg.grad_loss_weight > 0.0:
                self.losses_dict["grad_loss"] = {"loss_fn": GradL1Loss(), "weight": loss_cfg.grad_loss_weight}
            if loss_cfg.normals_consistency_loss_weight > 0.0:
                self.losses_dict["normals_consistency_loss"] = {"loss_fn": CosineNormalLoss(),
                                                                "weight": loss_cfg.normals_consistency_loss_weight}
        else:
            self.losses_dict["cosine_normal_loss"] = {"loss_fn": CosineNormalLoss(), "weight": 1.0}

    def calculate_depth_loss(self, batch, pred_cubemap, min_depth, depth_range, log_scale, metric_depth):
        loss = {"total_loss": 0.0}

        gt_depth_cubemap = batch["depth_cubemap"].squeeze(0).mean(dim=1, keepdim=True)
        mask_cubemap = batch["mask_cubemap"].squeeze(0)

        # For affine-invariant (non-metric) depth, align the prediction to the
        # ground truth first; in log space the multiplicative scale becomes an
        # additive shift.
        if not metric_depth:
            if log_scale:
                shift = compute_shift(pred_cubemap, gt_depth_cubemap, mask_cubemap)
                pred_cubemap = pred_cubemap + shift
            else:
                scale, shift = compute_scale_and_shift(pred_cubemap, gt_depth_cubemap, mask_cubemap)
                pred_cubemap = pred_cubemap * scale + shift
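        # The alignment above is assumed to be the MiDaS-style least-squares
        # fit (the project's compute_scale_and_shift may differ in detail):
        #     min_{s, t} sum_i m_i * (s * p_i + t - g_i)^2,
        # whose normal equations over the masked pixels are
        #     [sum(m*p*p)  sum(m*p)] [s]   [sum(m*p*g)]
        #     [sum(m*p)    sum(m)  ] [t] = [sum(m*g)  ].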
        for loss_name, loss_params in self.losses_dict.items():
            if loss_name == "normals_consistency_loss":
                gt = batch["normal"]
                mask = batch["mask"]
                pred_depth = self.process_depth_output(pred_cubemap, orig_size=gt.shape[2:], min_depth=min_depth,
                                                       depth_range=depth_range, log_scale=log_scale)[0]
                pred = depth_to_normals_erp(pred_depth).unsqueeze(0)
            else:
                pred = pred_cubemap
                gt = gt_depth_cubemap
                mask = mask_cubemap
            loss[loss_name] = loss_params["loss_fn"](pred, gt, mask)
            loss["total_loss"] += loss[loss_name] * loss_params["weight"]

        return loss

    def calculate_normal_loss(self, batch, pred_cubemap):
        loss = {"total_loss": 0.0}

        gt_normal_cubemap = batch["normal_cubemap"].squeeze(0)
        mask_cubemap = batch["mask_cubemap"].squeeze(0)

        for loss_name, loss_params in self.losses_dict.items():
            loss[loss_name] = loss_params["loss_fn"](pred_cubemap, gt_normal_cubemap, mask_cubemap)
            loss["total_loss"] += loss[loss_name] * loss_params["weight"]

        return loss

    def process_depth_output(self, pred_cubemap, orig_size, min_depth, depth_range, log_scale, mask=None):
        # Stitch the cubemap into an equirectangular panorama and map the
        # [-1, 1] network output onto [min_depth, min_depth + depth_range].
        pred_panorama = cubemap_to_erp(pred_cubemap, *orig_size)
        pred_panorama = torch.clamp(pred_panorama, -1, 1)
        pred_panorama = (pred_panorama + 1) / 2
        if mask is not None:
            pred_panorama *= mask
        pred_panorama = pred_panorama * depth_range + min_depth
        # Return depth in linear units plus a log-scale copy for visualization.
        if log_scale:
            pred_panorama_viz = pred_panorama.clone()
            pred_panorama = torch.exp(pred_panorama)
        else:
            pred_panorama_viz = torch.log(pred_panorama)

        return pred_panorama, pred_panorama_viz

    def process_normal_output(self, pred_cubemap, orig_size):
        pred_panorama = cubemap_to_erp(pred_cubemap, *orig_size)
        pred_panorama = torch.clamp(pred_panorama, -1, 1)
        return pred_panorama

    def extend_unet_conv_in(self, unet, new_in_channels: int):
        if new_in_channels < unet.conv_in.in_channels:
            raise ValueError(
                f"new_in_channels ({new_in_channels}) must be >= current "
                f"{unet.conv_in.in_channels}"
            )
        if new_in_channels == unet.conv_in.in_channels:
            return

        old_conv = unet.conv_in
        old_in = old_conv.in_channels
        device, dtype = old_conv.weight.device, old_conv.weight.dtype
        bias_flag = old_conv.bias is not None

        new_conv = Conv2d(
            new_in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=bias_flag,
            padding_mode=old_conv.padding_mode,
        ).to(device=device, dtype=dtype)

        # Zero-init the extra input channels so the extended conv initially
        # reproduces the pretrained behaviour; in-place edits of parameters
        # that require grad must run under no_grad.
        with torch.no_grad():
            new_conv.weight.zero_()
            new_conv.weight[:, :old_in].copy_(old_conv.weight)
            if bias_flag:
                new_conv.bias.copy_(old_conv.bias)

        unet.conv_in = new_conv
        # The diffusers config is a FrozenDict; update it through the
        # ConfigMixin API rather than item assignment.
        unet.register_to_config(in_channels=new_in_channels)

    def set_valid_pad_conv(self, module: nn.Module):
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Conv2d):
                if child.padding != (0, 0):
                    setattr(module, name, PaddedConv2d.from_existing(child, valid_pad_conv_fn))
                elif module.__class__.__name__ == "Downsample2D" and module.use_conv:
                    # Downsampling convs in diffusers pad only on one side.
                    setattr(module, name, PaddedConv2d.from_existing(child, valid_pad_conv_fn, one_side_pad=True))
            else:
                self.set_valid_pad_conv(child)

    def save_model(self, ema_unet, model_save_dir):
        self.unwrapped_unet.save_pretrained(model_save_dir / "original")
        if ema_unet is not None:
            # Temporarily swap in the EMA weights, save them, then restore
            # the current training weights.
            ema_unet.store(self.unwrapped_unet.parameters())
            ema_unet.copy_to(self.unwrapped_unet.parameters())
            self.unwrapped_unet.save_pretrained(model_save_dir / "EMA")
            ema_unet.restore(self.unwrapped_unet.parameters())
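# Minimal, self-contained sketch of the zero-init trick in extend_unet_conv_in,
# shown on a plain Conv2d with illustrative (hypothetical) channel counts:
# because the extra input channels start with zero weights, the extended conv
# reproduces the original conv's output regardless of what those channels hold.
if __name__ == "__main__":
    old = Conv2d(8, 4, kernel_size=3, padding=1)
    new = Conv2d(10, 4, kernel_size=3, padding=1)
    with torch.no_grad():
        new.weight.zero_()
        new.weight[:, :8].copy_(old.weight)  # original 8 channels keep their weights
        new.bias.copy_(old.bias)
    x = torch.randn(1, 8, 16, 16)    # stand-in for the base latents
    pe = torch.randn(1, 2, 16, 16)   # stand-in for the uv PE channels
    print(torch.allclose(old(x), new(torch.cat((x, pe), dim=1)), atol=1e-6))  # True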