# PaGeR — src/pager.py
# vulus98: xformers issue fixed (commit 0d7e1ad)
import torch
from torch import nn
from torch.nn import Conv2d
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import DDPMScheduler
from diffusers.utils.import_utils import is_xformers_available
from Marigold.unet.unet_2d_condition import UNet2DConditionModel
from Marigold.vae.autoencoder_kl import AutoencoderKL
from src.utils.conv_padding import PaddedConv2d, valid_pad_conv_fn
from src.utils.loss import L1Loss, GradL1Loss, CosineNormalLoss
from src.utils.geometry_utils import (
get_positional_encoding,
compute_scale_and_shift,
compute_shift,
depth_to_normals_erp,
cubemap_to_erp
)
class Pager(nn.Module):
    """Single-step diffusion estimator for panoramic depth and surface normals.

    Wraps frozen Stable-Diffusion components (VAE, CLIP text encoder) and one
    UNet per modality ("depth" / "normal").  RGB cubemap faces are encoded to
    latents, concatenated with an all-zero "noisy" latent (plus an optional UV
    positional encoding), denoised in a single step at the last training
    timestep, and decoded back to image space.
    """

    def __init__(self,
                 model_configs,
                 pretrained_path,
                 train_modality=None,
                 device=torch.device("cpu"),
                 weight_dtype=torch.float32):
        """
        Args:
            model_configs: mapping modality -> {"path", "mode", "config"},
                one entry per UNet checkpoint to load.
            pretrained_path: root of the pretrained SD pipeline (contains
                "scheduler", "tokenizer", "text_encoder", "vae" subfolders).
            train_modality: key of the UNet being trained, or None.
            device: target device for frozen components and cached buffers.
            weight_dtype: dtype for frozen weights and cached buffers.
        """
        super().__init__()
        self.model_configs = model_configs
        self.weight_dtype = weight_dtype
        # Standard SD latent scaling factor, shared by RGB and depth latents.
        self.rgb_latent_scale_factor = 0.18215
        self.depth_latent_scale_factor = 0.18215
        self.train_modality = train_modality
        self.device = device
        self.prepare_model_components(pretrained_path, model_configs)
        self.prepare_empty_encoding()
        # Cache only the schedule terms needed for the single-step update,
        # then drop the scheduler itself to free memory.
        self.alpha_prod = self.noise_scheduler.alphas_cumprod.to(device, dtype=weight_dtype)
        self.beta_prod = 1 - self.alpha_prod
        self.num_timesteps = self.noise_scheduler.config.num_train_timesteps - 1
        del self.noise_scheduler

    def prepare_model_components(self, pretrained_path, model_configs):
        """Load scheduler, tokenizer, text encoder, VAE and per-modality UNets.

        The VAE and text encoder are frozen, moved to the target device and
        put in eval mode.  xFormers memory-efficient attention is enabled
        when available and the device is CUDA.

        Raises:
            ValueError: if the checkpoints disagree on the shared VAE's
                positional-encoding configuration.
        """
        # All UNet checkpoints share one VAE, so they must agree on RoPE.
        vae_use_RoPE = None
        for checkpoint_cfg in model_configs.values():
            cfg_uses_rope = checkpoint_cfg['config'].vae_use_RoPE == "RoPE"
            if vae_use_RoPE is None:
                vae_use_RoPE = cfg_uses_rope
            elif vae_use_RoPE != cfg_uses_rope:
                raise ValueError("All UNet checkpoints must use the same VAE positional encoding configuration.")
        self.noise_scheduler = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler", rescale_betas_zero_snr=True)
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_path, subfolder="tokenizer", revision=None)
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_path, subfolder="text_encoder", revision=None, variant=None)
        self.vae = AutoencoderKL.from_pretrained(pretrained_path, subfolder="vae", revision=None, variant=None,
                                                 use_RoPE=vae_use_RoPE)
        self.set_valid_pad_conv(self.vae)
        # VAE and text encoder are frozen feature extractors.
        self.vae.requires_grad_(False)
        self.vae.to(self.device, dtype=self.weight_dtype)
        self.vae.eval()
        self.text_encoder.requires_grad_(False)
        self.text_encoder.to(self.device, dtype=self.weight_dtype)
        self.text_encoder.eval()
        # 8 = 4 RGB latent channels + 4 (zero) "noisy" latent channels.
        base_in_channels = 8
        # NOTE(review): pe_channels_size is sticky across iterations — once
        # any modality requests "uv" PE, later modalities also get the wider
        # target.  Preserved as-is; confirm this is intentional.
        pe_channels_size = 0
        self.unet = {}
        for modality, checkpoint_cfg in model_configs.items():
            if checkpoint_cfg['config'].unet_positional_encoding == "uv":
                pe_channels_size = 2
            target_in_channels = base_in_channels + pe_channels_size
            self.unet[modality] = UNet2DConditionModel.from_pretrained(
                checkpoint_cfg["path"],
                subfolder="unet",
                revision=None,
                in_channels=target_in_channels if checkpoint_cfg["mode"] == "trained" else base_in_channels,
                use_RoPE=checkpoint_cfg['config'].unet_positional_encoding == "RoPE"
            )
            # A not-yet-trained checkpoint is widened to accept the extra PE
            # channels; the new channels start zero-initialized.
            if target_in_channels > base_in_channels and checkpoint_cfg["mode"] != "trained":
                self.extend_unet_conv_in(self.unet[modality], new_in_channels=target_in_channels)
            self.set_valid_pad_conv(self.unet[modality])
        if is_xformers_available() and self.device.type == "cuda":
            # is_xformers_available() already verified importability; the
            # enable_* calls import xformers internally.
            if self.unet.get("depth"):
                self.unet["depth"].enable_xformers_memory_efficient_attention()
            if self.unet.get("normal"):
                self.unet["normal"].enable_xformers_memory_efficient_attention()
            self.vae.enable_xformers_memory_efficient_attention()
        else:
            print("xFormers is not available. Proceeding without it.")

    def prepare_training(self, accelerator, gradient_checkpointing):
        """Wrap the trained-modality UNet with accelerate.

        Keeps a handle to the unwrapped module (needed for saving) and
        optionally enables gradient checkpointing on UNet and VAE.
        """
        self.unwrapped_unet = self.unet[self.train_modality]
        self.unet[self.train_modality] = accelerator.prepare(self.unet[self.train_modality])
        self.trained_unet = self.unet[self.train_modality]
        if gradient_checkpointing:
            self.trained_unet._set_gradient_checkpointing()
            self.vae._set_gradient_checkpointing()

    def prepare_cubemap_PE(self, image_height, image_width):
        """Precompute the UV positional encoding if any modality needs it."""
        use_uv_PE = False
        for checkpoint_cfg in self.model_configs.values():
            if checkpoint_cfg['config'].unet_positional_encoding == "uv":
                use_uv_PE = True
        if use_uv_PE:
            PE_cubemap = get_positional_encoding(image_height, image_width)
            self.PE_cubemap = PE_cubemap.to(device=self.device, dtype=self.weight_dtype)

    def prepare_empty_encoding(self):
        """Cache the CLIP encoding of the empty prompt, then free CLIP.

        Conditioning is always the empty string, so the tokenizer and text
        encoder are only needed once and are deleted afterwards.
        """
        with torch.inference_mode():
            empty_token = self.tokenizer([""], padding="max_length", truncation=True, return_tensors="pt").input_ids
            empty_token = empty_token.to(self.device)
            empty_encoding = self.text_encoder(empty_token, return_dict=False)[0]
            self.empty_encoding = empty_encoding.to(self.device, dtype=self.weight_dtype)
        del empty_token
        del self.text_encoder
        del self.tokenizer

    def forward(self, batch, modality):
        """Run one deterministic single-step denoising pass for `modality`.

        Args:
            batch: dict with "rgb_cubemap" — assumed (B, 6, C, H, W), the 6
                cube faces per panorama (inferred from the `* 6` repeat below
                — TODO confirm against the data loader).
            modality: "depth" or "normal".

        Returns:
            Decoded cubemap prediction; for "depth" the three decoded
            channels are averaged down to one.
        """
        with torch.inference_mode():
            c, h, w = batch["rgb_cubemap"].shape[2:]
            # Fold the cube faces into the batch dimension for the VAE.
            rgb_vae_input = batch["rgb_cubemap"].reshape(-1, c, h, w).to(dtype=self.weight_dtype)
            rgb_latents = self.vae.encode(rgb_vae_input, deterministic=True)
            rgb_latents = rgb_latents * self.rgb_latent_scale_factor
            del rgb_vae_input
        # Always denoise at the final training timestep (single-step scheme).
        timesteps = torch.ones((rgb_latents.shape[0],), device=self.device) * self.num_timesteps
        timesteps = timesteps.long()
        alpha_prod_t = self.alpha_prod[timesteps].view(-1, 1, 1, 1)
        beta_prod_t = self.beta_prod[timesteps].view(-1, 1, 1, 1)
        # The "noisy" latent is fixed at zero — the UNet predicts directly
        # from the RGB latent conditioning.
        noisy_latents = torch.zeros_like(rgb_latents).to(self.device)
        encoder_hidden_states = self.empty_encoding.repeat(batch["rgb_cubemap"].shape[0] * 6, 1, 1)
        if self.model_configs[modality]['config'].unet_positional_encoding == "uv":
            batch_PE_cubemap = self.PE_cubemap.repeat(batch["rgb_cubemap"].shape[0], 1, 1, 1)
            unet_input = torch.cat((rgb_latents, noisy_latents, batch_PE_cubemap), dim=1).to(
                self.device
            )
        else:
            unet_input = torch.cat((rgb_latents, noisy_latents), dim=1).to(self.device)
        del rgb_latents
        model_pred = self.unet[modality](
            unet_input,
            timesteps,
            encoder_hidden_states,
            return_dict=False,
        )[0]
        # Recover the clean-latent estimate from the model output —
        # presumably v-parameterization (x0 = sqrt(a)*x_t - sqrt(1-a)*v),
        # consistent with rescale_betas_zero_snr above; confirm.
        current_latent_estimate = (alpha_prod_t**0.5) * noisy_latents - (beta_prod_t**0.5) * model_pred
        current_scaled_latent_estimate = current_latent_estimate / self.depth_latent_scale_factor
        pred_cubemap = self.vae.decode(current_scaled_latent_estimate, deterministic=True)
        if modality == "depth":
            # Depth is replicated across channels by the decoder; average
            # back to a single channel.
            pred_cubemap = pred_cubemap.mean(dim=1, keepdim=True)
        return pred_cubemap

    def prepare_losses_dict(self, loss_cfg):
        """Build `self.losses_dict` ({name: {"loss_fn", "weight"}}) per modality.

        Depth uses L1 plus optional gradient and normals-consistency terms;
        the normal modality uses a single cosine loss with weight 1.
        """
        self.losses_dict = {}
        if self.train_modality == "depth":
            self.losses_dict["l1_loss"] = {"loss_fn": L1Loss(invalid_mask_weight=loss_cfg.invalid_mask_weight),
                                           "weight": loss_cfg.l1_loss_weight}
            if loss_cfg.grad_loss_weight > 0.0:
                self.losses_dict["grad_loss"] = {"loss_fn": GradL1Loss(), "weight": loss_cfg.grad_loss_weight}
            if loss_cfg.normals_consistency_loss_weight > 0.0:
                self.losses_dict["normals_consistency_loss"] = {"loss_fn": CosineNormalLoss(),
                                                                "weight": loss_cfg.normals_consistency_loss_weight}
        else:
            self.losses_dict["cosine_normal_loss"] = {"loss_fn": CosineNormalLoss(), "weight": 1.0}

    def calculate_depth_loss(self, batch, pred_cubemap, min_depth, depth_range, log_scale, metric_depth):
        """Compute the weighted depth losses for one batch.

        For non-metric depth the prediction is first aligned to the ground
        truth: additive shift in log scale, affine scale+shift otherwise.

        Returns:
            dict of individual loss values plus their weighted sum under
            "total_loss".
        """
        loss = {"total_loss": 0.0}
        gt_depth_cubemap = batch['depth_cubemap'].squeeze(0).mean(dim=1, keepdim=True)
        mask_cubemap = batch["mask_cubemap"].squeeze(0)
        if not metric_depth:
            if log_scale:
                scale = compute_shift(pred_cubemap, gt_depth_cubemap, mask_cubemap)
            else:
                scale, shift = compute_scale_and_shift(pred_cubemap, gt_depth_cubemap, mask_cubemap)
            if log_scale:
                # In log space the alignment is a pure additive shift.
                pred_cubemap += scale
            else:
                pred_cubemap = pred_cubemap * scale + shift
        for loss_name, loss_params in self.losses_dict.items():
            if loss_name == "normals_consistency_loss":
                # Consistency term: normals derived from predicted depth on
                # the ERP panorama vs. ground-truth normals.
                gt = batch['normal']
                pred_depth = pred_cubemap
                mask = batch["mask"]
                pred_depth = self.process_depth_output(pred_depth, orig_size=gt.shape[2:], min_depth=min_depth,
                                                      depth_range=depth_range, log_scale=log_scale)[0]
                pred = depth_to_normals_erp(pred_depth).unsqueeze(0)
            else:
                pred = pred_cubemap
                gt = gt_depth_cubemap
                mask = mask_cubemap
            loss[loss_name] = loss_params["loss_fn"](pred, gt, mask)
            loss["total_loss"] += loss[loss_name] * loss_params["weight"]
        return loss

    def calculate_normal_loss(self, batch, pred_cubemap):
        """Compute the weighted normal losses for one batch.

        Returns:
            dict of individual loss values plus "total_loss".
        """
        loss = {"total_loss": 0.0}
        gt_normal_cubemap = batch['normal_cubemap'].squeeze(0)
        mask_cubemap = batch["mask_cubemap"].squeeze(0)
        for loss_name, loss_params in self.losses_dict.items():
            pred = pred_cubemap
            gt = gt_normal_cubemap
            loss[loss_name] = loss_params["loss_fn"](pred, gt, mask_cubemap)
            loss["total_loss"] += loss[loss_name] * loss_params["weight"]
        return loss

    def process_depth_output(self, pred_cubemap, orig_size, min_depth, depth_range, log_scale, mask=None):
        """Convert a depth cubemap to a metric ERP panorama.

        The [-1, 1] network output is mapped to [min_depth, min_depth +
        depth_range]; in log scale the result is exponentiated to metric
        depth.

        Returns:
            (pred_panorama, pred_panorama_viz) — the metric panorama and a
            log-space copy for visualization.
        """
        pred_panorama = cubemap_to_erp(pred_cubemap, *orig_size)
        pred_panorama = torch.clamp(pred_panorama, -1, 1)
        pred_panorama = (pred_panorama + 1) / 2
        if mask is not None:
            pred_panorama *= mask
        pred_panorama = pred_panorama * depth_range + min_depth
        if log_scale:
            pred_panorama_viz = pred_panorama.clone()
            pred_panorama = torch.exp(pred_panorama)
        else:
            pred_panorama_viz = torch.log(pred_panorama)
        return pred_panorama, pred_panorama_viz

    def process_normal_output(self, pred_cubemap, orig_size):
        """Convert a normal cubemap to an ERP panorama, clamped to [-1, 1]."""
        pred_panorama = cubemap_to_erp(pred_cubemap, *orig_size)
        pred_panorama = torch.clamp(pred_panorama, -1, 1)
        return pred_panorama

    def extend_unet_conv_in(self, unet, new_in_channels: int):
        """Widen `unet.conv_in` to accept `new_in_channels` input channels.

        Existing weights (and bias) are preserved; the new trailing input
        channels are zero-initialized so the widened model behaves exactly
        like the original until those channels carry signal.

        Raises:
            ValueError: if `new_in_channels` is smaller than the current
                number of input channels.
        """
        if new_in_channels < unet.conv_in.in_channels:
            raise ValueError(
                f"new_in_channels ({new_in_channels}) must be >= current "
                f"{unet.conv_in.in_channels}"
            )
        if new_in_channels == unet.conv_in.in_channels:
            return
        old_conv = unet.conv_in
        old_in = old_conv.in_channels
        device, dtype = old_conv.weight.device, old_conv.weight.dtype
        bias_flag = old_conv.bias is not None
        new_conv = Conv2d(
            new_in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=bias_flag,
            padding_mode=old_conv.padding_mode,
        ).to(device=device, dtype=dtype)
        # BUGFIX: in-place edits (zero_/copy_) of leaf parameters that
        # require grad raise a RuntimeError; the weight surgery must run
        # under no_grad().
        with torch.no_grad():
            new_conv.weight.zero_()
            new_conv.weight[:, :old_in].copy_(old_conv.weight)
            if bias_flag:
                new_conv.bias.copy_(old_conv.bias)
        unet.conv_in = new_conv
        unet.config["in_channels"] = new_in_channels

    def set_valid_pad_conv(self, module: nn.Module):
        """Recursively replace padded Conv2d layers with PaddedConv2d.

        Convolutions with non-zero padding get the valid-pad wrapper;
        inside a diffusers Downsample2D using conv downsampling, a one-sided
        pad variant is used instead.
        """
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Conv2d):
                if child.padding != (0, 0):
                    setattr(module, name, PaddedConv2d.from_existing(child, valid_pad_conv_fn))
                # NOTE(review): this branch only fires for zero-padding convs
                # inside Downsample2D — confirm a padded Downsample2D conv
                # should not also take the one_side_pad path.
                elif module.__class__.__name__ == "Downsample2D" and module.use_conv:
                    setattr(module, name, PaddedConv2d.from_existing(child, valid_pad_conv_fn, one_side_pad=True))
            else:
                self.set_valid_pad_conv(child)

    def save_model(self, ema_unet, model_save_dir):
        """Save the trained UNet; additionally save an EMA snapshot if given.

        The EMA weights are temporarily swapped into the unwrapped UNet,
        saved under "EMA", then the original weights are restored.
        """
        self.unwrapped_unet.save_pretrained(model_save_dir / "original")
        if ema_unet is not None:
            ema_unet.store(self.unwrapped_unet.parameters())
            ema_unet.copy_to(self.unwrapped_unet.parameters())
            self.unwrapped_unet.save_pretrained(model_save_dir / "EMA")
            ema_unet.restore(self.unwrapped_unet.parameters())