"""Training-loss functions and a name -> function registry.

Every loss takes ``(model, batch)`` (the reconstruction / velocity losses
also accept ``use_txt_output``) and returns the value produced by the
underlying loss computation. ``batch`` is indexed with string keys.
"""
import torch
import torch.nn.functional as F

from .hessian_penalty import hessian_penalty
from .mmd import compute_mmd


def _masked_mse(target, prediction, mask):
    """Return the mean squared error over the entries selected by ``mask``.

    Both tensors are permuted with ``(0, 3, 1, 2)`` (so they must have
    exactly four axes), which moves the last axis to position 1, and are
    then indexed with ``mask``.
    """
    # presumably mask is a boolean (batch, frames) tensor marking valid
    # frames -- TODO confirm against the data loader.
    target_selected = target.permute(0, 3, 1, 2)[mask]
    prediction_selected = prediction.permute(0, 3, 1, 2)[mask]
    return F.mse_loss(target_selected, prediction_selected, reduction="mean")


def _frame_differences(sequence):
    """Return first-order differences of ``sequence`` along its last axis."""
    return sequence[..., 1:] - sequence[..., :-1]


def compute_rc_loss(model, batch, use_txt_output=False):
    """Masked MSE between ``batch["x"]`` and the model output.

    Args:
        model: Unused; kept so every loss shares the same call signature.
        batch: Mapping providing ``"x"``, ``"mask"`` and either ``"output"``
            or, when ``use_txt_output`` is true, ``"txt_output"``.
        use_txt_output: Compare against ``batch["txt_output"]`` instead of
            ``batch["output"]``.

    Returns:
        The ``F.mse_loss`` value (``reduction="mean"``) over masked entries.
    """
    # Only the entry that is actually compared gets looked up.
    prediction = batch["txt_output"] if use_txt_output else batch["output"]
    return _masked_mse(batch["x"], prediction, batch["mask"])


def compute_rcxyz_loss(model, batch, use_txt_output=False):
    """Masked MSE between ``batch["x_xyz"]`` and the model's xyz output.

    Args:
        model: Unused; kept so every loss shares the same call signature.
        batch: Mapping providing ``"x_xyz"``, ``"mask"`` and either
            ``"output_xyz"`` or, when ``use_txt_output`` is true,
            ``"txt_output_xyz"``.
        use_txt_output: Compare against ``batch["txt_output_xyz"]`` instead
            of ``batch["output_xyz"]``.

    Returns:
        The ``F.mse_loss`` value (``reduction="mean"``) over masked entries.
    """
    prediction = (
        batch["txt_output_xyz"] if use_txt_output else batch["output_xyz"]
    )
    return _masked_mse(batch["x_xyz"], prediction, batch["mask"])


def compute_vel_loss(model, batch, use_txt_output=False):
    """Masked MSE between last-axis differences of ``batch["x"]`` and output.

    Args:
        model: Unused; kept so every loss shares the same call signature.
        batch: Mapping providing ``"x"``, ``"mask"`` and either ``"output"``
            or, when ``use_txt_output`` is true, ``"txt_output"``.
        use_txt_output: Compare against ``batch["txt_output"]`` instead of
            ``batch["output"]``.

    Returns:
        The ``F.mse_loss`` value (``reduction="mean"``) over masked entries.
    """
    prediction = batch["txt_output"] if use_txt_output else batch["output"]
    # Differencing shortens the last axis by one, so the first mask entry
    # along its last axis is dropped to keep the lengths equal.
    mask = batch["mask"][..., 1:]
    return _masked_mse(
        _frame_differences(batch["x"]), _frame_differences(prediction), mask
    )


def compute_velxyz_loss(model, batch, use_txt_output=False):
    """Masked MSE between last-axis differences of ``batch["x_xyz"]`` and output.

    Args:
        model: Unused; kept so every loss shares the same call signature.
        batch: Mapping providing ``"x_xyz"``, ``"mask"`` and either
            ``"output_xyz"`` or, when ``use_txt_output`` is true,
            ``"txt_output_xyz"``.
        use_txt_output: Compare against ``batch["txt_output_xyz"]`` instead
            of ``batch["output_xyz"]``.

    Returns:
        The ``F.mse_loss`` value (``reduction="mean"``) over masked entries.
    """
    prediction = (
        batch["txt_output_xyz"] if use_txt_output else batch["output_xyz"]
    )
    # Differencing shortens the last axis by one, so the first mask entry
    # along its last axis is dropped to keep the lengths equal.
    mask = batch["mask"][..., 1:]
    return _masked_mse(
        _frame_differences(batch["x_xyz"]),
        _frame_differences(prediction),
        mask,
    )


def compute_hp_loss(model, batch):
    """Return ``hessian_penalty`` evaluated on ``model.return_latent``.

    Args:
        model: Object exposing a ``return_latent`` callable.
        batch: Batch forwarded unchanged to ``hessian_penalty``.

    Returns:
        Whatever ``hessian_penalty`` returns for this batch.
    """
    # NOTE(review): per the PyTorch docs, torch.random.seed() re-seeds the
    # global RNG with a non-deterministic value and returns that value, so
    # each call discards any earlier torch.manual_seed(). How the `seed`
    # keyword is consumed is decided by hessian_penalty / return_latent,
    # which are not visible here -- confirm before changing.
    return hessian_penalty(
        model.return_latent, batch, seed=torch.random.seed()
    )


def compute_mmd_loss(model, batch):
    """Return ``compute_mmd`` between standard-normal samples and ``batch["z"]``.

    Args:
        model: Object exposing ``device``; the samples are created there.
        batch: Mapping providing ``"z"``.

    Returns:
        Whatever ``compute_mmd`` returns for ``(samples, z)``.
    """
    z = batch["z"]
    # Samples share z's shape but live on model.device, as in the original.
    reference_samples = torch.randn(
        z.shape, requires_grad=False, device=model.device
    )
    return compute_mmd(reference_samples, z)


# Registry mapping each loss's short name to its implementation.
_matching_ = {
    "rc": compute_rc_loss,
    "hp": compute_hp_loss,
    "mmd": compute_mmd_loss,
    "rcxyz": compute_rcxyz_loss,
    "vel": compute_vel_loss,
    "velxyz": compute_velxyz_loss,
}


def get_loss_function(ltype):
    """Return the loss function registered under ``ltype``.

    Raises:
        KeyError: If ``ltype`` is not a registered loss name; the message
            lists the valid names.
    """
    try:
        return _matching_[ltype]
    except KeyError:
        raise KeyError(
            f"Unknown loss type {ltype!r}; expected one of {list(_matching_)}"
        ) from None


def get_loss_names():
    """Return the registered loss names as a new list, in registry order."""
    return list(_matching_)