from typing import Optional, Union, Tuple
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torchvision.models import vgg16, VGG16_Weights
from einops import rearrange
from transformers import PreTrainedModel
from transformers.utils import logging, ModelOutput

from .configuration_vae import VAEConfig, EncoderType, DecoderType

logger = logging.get_logger(__name__)


@dataclass
class VAEOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    reconstruction: Optional[torch.FloatTensor] = None
    mse_loss: Optional[torch.FloatTensor] = None
    l1_loss: Optional[torch.FloatTensor] = None
    perceptual_loss: Optional[torch.FloatTensor] = None
    dino_loss: Optional[torch.FloatTensor] = None
    kl_loss: Optional[torch.FloatTensor] = None


class Vgg16(nn.Module):
    """Frozen VGG16 feature extractor that returns the activations after
    relu1_2, relu2_2, relu3_3 and relu4_3.

    ref: https://github.com/dxyang/StyleTransfer/blob/master/vgg.py
    """

    def __init__(self):
        super().__init__()
        features = vgg16(weights=VGG16_Weights.DEFAULT).features
        self.to_relu_1_2 = nn.Sequential()
        self.to_relu_2_2 = nn.Sequential()
        self.to_relu_3_3 = nn.Sequential()
        self.to_relu_4_3 = nn.Sequential()
        for x in range(4):
            self.to_relu_1_2.add_module(str(x), features[x])
        for x in range(4, 9):
            self.to_relu_2_2.add_module(str(x), features[x])
        for x in range(9, 16):
            self.to_relu_3_3.add_module(str(x), features[x])
        for x in range(16, 23):
            self.to_relu_4_3.add_module(str(x), features[x])
        # We don't need the gradients, just the features.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        h_relu_1_2 = self.to_relu_1_2(x)
        h_relu_2_2 = self.to_relu_2_2(h_relu_1_2)
        h_relu_3_3 = self.to_relu_3_3(h_relu_2_2)
        h_relu_4_3 = self.to_relu_4_3(h_relu_3_3)
        return (h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3)


class PerceptualLoss(nn.Module):
    """Weighted MSE between VGG16 feature maps of two images.

    If `unnorm_mean`/`unnorm_std` are given, inputs are first mapped back to
    [0, 1] and then re-normalized with ImageNet statistics, which is what the
    pretrained VGG expects.
    """

    def __init__(self, layers=(3, 8, 15, 22), unnorm_mean=None, unnorm_std=None, weights=None):
        super().__init__()
        self.vgg = Vgg16()
        self.weights = weights or [1.0 / len(layers)] * len(layers)
        self.renorm = unnorm_mean is not None and unnorm_std is not None
        if self.renorm:
            self.register_buffer("unnorm_mean", torch.tensor(unnorm_mean).view(1, -1, 1, 1))
            self.register_buffer("unnorm_std", torch.tensor(unnorm_std).view(1, -1, 1, 1))
            self.register_buffer("vgg_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))
            self.register_buffer("vgg_std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))

    def _prepare(self, x):
        if not self.renorm:
            return x
        x = x * self.unnorm_std + self.unnorm_mean  # undo dataset normalization
        return (x - self.vgg_mean) / self.vgg_std  # apply ImageNet normalization

    def forward(self, x, y):
        x_vgg = self.vgg(self._prepare(x))
        y_vgg = self.vgg(self._prepare(y))
        loss = 0.0
        for weight, x_feat, y_feat in zip(self.weights, x_vgg, y_vgg):
            loss += weight * F.mse_loss(x_feat, y_feat)
        return loss
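
# A minimal smoke test for PerceptualLoss, kept as a comment so importing this
# module has no side effects. The shapes and normalization values below are
# illustrative assumptions, not values used elsewhere in this repo:
#
#   ploss = PerceptualLoss(unnorm_mean=(0.5, 0.5, 0.5), unnorm_std=(0.5, 0.5, 0.5))
#   x, y = torch.rand(2, 3, 224, 224), torch.rand(2, 3, 224, 224)
#   print(ploss(x, y))  # scalar tensor; exactly 0 when x == y
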
class DinoLoss(nn.Module):
    """Cosine-distance loss between frozen DINO ViT patch features of the
    ground-truth image and a spatial feature map produced by the model."""

    def __init__(self, patch_size, use_large=False):
        super().__init__()
        size = 'b' if use_large else 's'
        dino = f'dino_vit{size}{patch_size}'
        self.vit = torch.hub.load('facebookresearch/dino:main', dino)
        logger.info("using %s", dino)
        self.vit.eval()
        for param in self.vit.parameters():
            param.requires_grad = False

    def forward(self, gt, embed):
        # DINO features are targets only, so no gradient is needed for them.
        with torch.no_grad():
            dino_features = self.vit.prepare_tokens(gt)
            for blk in self.vit.blocks:
                dino_features = blk(dino_features)
            dino_features = self.vit.norm(dino_features)
            dino_features = dino_features[:, 1:]  # drop the [CLS] token
        embed_features = rearrange(embed, 'b c h w -> b (h w) c').contiguous()
        dtype = embed.dtype
        # Compare in float32 for numerical stability, then cast back.
        dino_loss = 1 - F.cosine_similarity(
            dino_features.to(torch.float32), embed_features.to(torch.float32), dim=2
        )
        return dino_loss.mean().to(dtype)
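
# Shape contract for DinoLoss, as a commented sketch (the 224x224 resolution
# and patch_size=16 are assumptions for illustration): dino_vits16 turns a
# (B, 3, 224, 224) image into (B, 196, 384) patch tokens, so `embed` must be a
# (B, 384, 14, 14) feature map whose grid matches the 14x14 patch layout:
#
#   dino = DinoLoss(patch_size=16)      # downloads dino_vits16 via torch.hub
#   gt = torch.rand(2, 3, 224, 224)
#   embed = torch.rand(2, 384, 14, 14)
#   print(dino(gt, embed))              # scalar in [0, 2]
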
class VAEModel(PreTrainedModel):
    config_class = VAEConfig
    main_input_name = "s0_img"

    def __init__(self, config: VAEConfig):
        super().__init__(config)
        dict_config = config.to_dict()
        self.encoder = EncoderType[config.encoder_type].value(**dict_config)
        # Flattened encoder output size: z_channels feature maps at the
        # downsampled resolution (one 2x downsample per extra channel mult);
        # assumes square inputs.
        enc_out_dim = config.z_channels * (config.resolution[0] // (2 ** (len(config.channels_mult) - 1))) ** 2
        latent_dim = 64
        self.cond_mlp = nn.Sequential(
            nn.Linear(enc_out_dim * 2, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, latent_dim * 2),
        )
        self.in_mlp = nn.Sequential(
            nn.Linear(enc_out_dim, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, latent_dim * 2),
        )
        self.cond_mlp_out = nn.Sequential(
            nn.Linear(latent_dim + enc_out_dim, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, enc_out_dim),
        )
        self.out_mlp = nn.Sequential(
            nn.Linear(latent_dim, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, config.z_channels),
            nn.ReLU(),
            nn.Linear(config.z_channels, enc_out_dim),
        )
        self.decoder = DecoderType[config.decoder_type].value(**dict_config)
        if config.w_perceptual > 0:
            self.perceptual_loss = PerceptualLoss(
                unnorm_mean=config.image_mean, unnorm_std=config.image_std
            )
        if config.w_dino > 0:
            # The spatial latent channels must match the DINO embedding dim:
            # 384 for ViT-S, 768 for ViT-B.
            assert config.z_channels in [384, 768]
            patch_size = 2 ** (len(config.channels_mult) - 1)
            self.dino_loss = DinoLoss(patch_size=patch_size, use_large=(config.z_channels == 768))
        self.log_state = {
            "loss": None,
            "mse_loss": None,
            "l1_loss": None,
            "perceptual_loss": None,
            "dino_loss": None,
            "kl_loss": None,
            "gt": None,
            "recon": None,
        }
        self.post_init()

    def encode(self, s0_img: Tensor, s1_img: Tensor, a0: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
        # Conditioning on s0 is currently disabled; only s1 is encoded.
        # s0 = self.encoder(s0_img).reshape(s0_img.shape[0], -1)
        s0 = None
        s1 = self.encoder(s1_img).reshape(s1_img.shape[0], -1)
        # s1_mean_var = self.cond_mlp(torch.cat([s0, s1], dim=1))
        s1_mean_var = self.in_mlp(s1)
        s1_mean, s1_logvar = s1_mean_var.chunk(2, dim=1)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
        s1_stddev = torch.exp(s1_logvar * 0.5)
        s1_latent = s1_mean + s1_stddev * torch.randn_like(s1_mean)
        return s1_latent, s0, s1_mean, s1_logvar

    def decode(self, s1_latent: Tensor, s0: Optional[Tensor], return_quant: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        quant_h = self.config.resolution[0] // (2 ** (len(self.config.channels_mult) - 1))
        quant_w = self.config.resolution[1] // (2 ** (len(self.config.channels_mult) - 1))
        # quant = self.cond_mlp_out(torch.cat([s1_latent, s0], dim=1)).reshape(
        #     s1_latent.shape[0], self.config.z_channels, quant_h, quant_w)
        quant = self.out_mlp(s1_latent).reshape(s1_latent.shape[0], self.config.z_channels, quant_h, quant_w)
        recon = self.decoder(quant)
        # The spatial latent is also needed by the DINO feature loss.
        if return_quant:
            return recon, quant
        return recon

    def forward(
        self,
        s0_img: Tensor,
        s1_img: Tensor,
        action: Tensor,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, VAEOutput]:
        return_dict = return_dict if return_dict is not None else False
        s1_latent, s0, s1_mean, s1_logvar = self.encode(s0_img, s1_img, action)
        recon, quant = self.decode(s1_latent, s0, return_quant=True)

        loss = mse_loss = l1_loss = perceptual_loss = dino_loss = kl_loss = None
        if return_loss:
            # Reconstruction losses.
            mse_loss = F.mse_loss(recon, s1_img)
            l1_loss = F.l1_loss(recon, s1_img)
            if self.config.w_perceptual > 0:
                perceptual_loss = self.perceptual_loss(recon, s1_img)
            else:
                perceptual_loss = torch.zeros_like(mse_loss)
            if self.config.w_dino > 0:
                dino_loss = self.dino_loss(s1_img, quant)
            else:
                dino_loss = torch.zeros_like(mse_loss)
            # KL divergence between the posterior N(mu, sigma^2) and N(0, I).
            kl_loss = torch.mean(-0.5 * torch.sum(1 + s1_logvar - s1_mean ** 2 - s1_logvar.exp(), dim=1))
            loss = (
                self.config.w_mse * mse_loss
                + self.config.w_l1 * l1_loss
                + self.config.w_perceptual * perceptual_loss
                + self.config.w_dino * dino_loss
                + self.config.w_kl * kl_loss
            )

        if not return_dict:
            if loss is not None:
                self.log_state["loss"] = loss.item()
                self.log_state["mse_loss"] = mse_loss.item()
                self.log_state["l1_loss"] = l1_loss.item()
                self.log_state["perceptual_loss"] = perceptual_loss.item()
                self.log_state["dino_loss"] = dino_loss.item()
                self.log_state["kl_loss"] = kl_loss.item()
                self.log_state["gt"] = s1_img.clone().detach().cpu()[:4].to(torch.float32)
                self.log_state["recon"] = recon.clone().detach().cpu()[:4].to(torch.float32)
            return (loss, recon) if loss is not None else recon

        return VAEOutput(
            loss=loss,
            reconstruction=recon,
            mse_loss=mse_loss,
            l1_loss=l1_loss,
            perceptual_loss=perceptual_loss,
            dino_loss=dino_loss,
            kl_loss=kl_loss,
        )

    def get_last_layer(self):
        raise NotImplementedError
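
# Sketch of end-to-end usage, kept as a comment. The config values are
# hypothetical; the actual required fields (resolution, channels_mult,
# z_channels, encoder/decoder types, w_* loss weights, image_mean/std) are
# defined in configuration_vae.VAEConfig, and the action shape depends on the
# dataset:
#
#   config = VAEConfig(...)
#   model = VAEModel(config)
#   s0 = torch.rand(2, 3, *config.resolution)
#   s1 = torch.rand(2, 3, *config.resolution)
#   action = torch.zeros(2, 4)
#   out = model(s0, s1, action, return_dict=True)
#   print(out.loss, out.reconstruction.shape)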