import contextlib

import torch

from utils.interpolation_utils.embedding import TimestepEmbedder, get_pos_embedding
from utils.interpolation_utils.klperceptual import KLLPIPSWithDiscriminator
from utils.interpolation_utils.distributions import DiagonalGaussianDistribution
from utils.interpolation_utils.cal_metrics import CalMetrics


class InputPadder:
    """Pad images so their spatial size is a multiple of ``divisor``, and undo it.

    Padding is split as evenly as possible between the two sides of each axis;
    when the amount is odd, the extra pixel goes to the right / bottom side.
    """

    def __init__(self, img_size, divisor=32):
        """
        Args:
            img_size: ``(height, width)`` of the unpadded images.
            divisor: Height and width are padded up to the next multiple of this.
        """
        self.ht, self.wd = img_size
        # The trailing ``% divisor`` makes the padding 0 when the size is
        # already a multiple of ``divisor``.
        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
        # Order expected by torch.nn.functional.pad for the last two dims:
        # (left, right, top, bottom).
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]

    def pad(self, x):
        """Pad the last two dims of ``x`` by replicating its edge pixels."""
        return torch.nn.functional.pad(x, self._pad, mode="replicate")

    def unpad(self, x):
        """Crop the padding amounts computed in ``__init__`` off the last two dims of ``x``.

        Assumes ``x`` has the padded spatial size produced by :meth:`pad`.
        """
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]


def preprocess_cond(x, eps=1e-8):
    """Standardize each sample of ``x`` and return the pair-averaged statistics.

    The batch is treated as two equal halves (``x[:B // 2]`` and ``x[B // 2:]``);
    sample ``i`` of the first half is paired with sample ``i`` of the second.

    Args:
        x: Tensor whose first dim is the batch; all remaining dims are reduced
            when computing per-sample statistics.
        eps: Added to the standard deviation to avoid division by zero.

    Returns:
        ``(x_norm, (mean, std))`` where ``x_norm`` is every sample of ``x``
        shifted by its mean and divided by its (std + eps), and ``mean`` /
        ``std`` are the element-wise averages of the two halves' per-sample
        statistics, shaped ``(B // 2, 1, ..., 1)`` to broadcast against a
        half-batch of ``x``.

    Raises:
        ValueError: If the batch size is odd, so the halves cannot be paired.
    """
    batch = x.shape[0]
    if batch % 2 != 0:
        # chunk(2) would yield unequal halves: B=1 fails to unpack, and B>=3
        # silently broadcasts mismatched samples into wrongly-shaped stats.
        raise ValueError(f"preprocess_cond expects an even batch size, got {batch}")

    x_flat = x.flatten(1)
    x_mean = torch.mean(x_flat, dim=-1)
    x_std = torch.std(x_flat, dim=-1) + eps
    # Append singleton dims so the per-sample stats broadcast over x.
    while x_mean.dim() < x.dim():
        x_mean, x_std = x_mean.unsqueeze(-1), x_std.unsqueeze(-1)
    x_norm = (x - x_mean) / x_std

    x_mean_0, x_mean_1 = x_mean.chunk(2, dim=0)
    x_std_0, x_std_1 = x_std.chunk(2, dim=0)
    stats = ((x_mean_0 + x_mean_1) / 2, (x_std_0 + x_std_1) / 2)
    return x_norm, stats


def preprocess_frames(frames):
    """Scale a batch of frame triplets by 1/255 and stack them along the batch dim.

    Args:
        frames: Tensor indexed as ``frames[:, k, ...]`` where ``k`` = 0 is the
            first input frame, 1 the second input frame, and 2 the ground-truth
            frame. Presumably 8-bit pixel values in [0, 255] -- TODO confirm.

    Returns:
        ``(frames, padder, frame_0, frame_1, gt)`` where ``frames`` is
        ``torch.cat((frame_0, frame_1, gt), dim=0)`` and ``padder`` is an
        :class:`InputPadder` sized to the last two (spatial) dims of ``frames``.
    """
    frames = frames / 255.
    frame_0, frame_1, gt = frames[:, 0, ...], frames[:, 1, ...], frames[:, 2, ...]
    frames = torch.cat((frame_0, frame_1, gt), dim=0)
    # InputPadder pads/crops the last two dims, so size it from those same dims.
    padder = InputPadder(list(frames.shape[-2:]))
    return frames, padder, frame_0, frame_1, gt


def one_iter_for_vae(model, frames, is_train=True):
    """Run the VAE forward pass on one batch of frame triplets.

    Args:
        model: Called on the padded, stacked frames; must return
            ``(recon, posterior)``.
        frames: Raw frame-triplet batch, see :func:`preprocess_frames`.
        is_train: When False the forward pass runs under ``torch.no_grad()``;
            when True the caller's current grad mode is left unchanged.

    Returns:
        ``(recon, gt, posterior)``: the reconstruction clamped to [0, 1] with
        the padding cropped off, the scaled ground-truth frames, and the
        posterior returned by ``model``.
    """
    frames, padder, _, _, gt = preprocess_frames(frames)
    # nullcontext (rather than set_grad_enabled(True)) keeps any outer
    # no_grad the caller may have entered.
    grad_ctx = contextlib.nullcontext() if is_train else torch.no_grad()
    with grad_ctx:
        recon, posterior = model(padder.pad(frames))
    recon = padder.unpad(recon.clamp(0., 1.))
    return recon, gt, posterior


def one_iter_for_dit(model, vae, frames, transport, sample_fn, vae_mean, vae_scaler,
                     cos_sim_mean, cos_sim_std, is_train=True):
    """Run one training or sampling step of the latent diffusion model.

    ``model`` and ``vae`` are accessed through ``.module`` (presumably
    DistributedDataParallel wrappers -- confirm against the caller).

    Args:
        model: Denoiser passed to ``transport.training_losses`` when training,
            and whose ``.module.forward`` is passed to ``sample_fn`` otherwise.
        vae: Provides ``.module.encode`` -> ``(posterior, cond_tokens)`` and
            ``.module.decode(latent, cond_tokens)``.
        frames: Raw frame-triplet batch, see :func:`preprocess_frames`.
        transport: Object exposing ``training_losses(model, latent, **kwargs)``.
        sample_fn: Called as ``sample_fn(noise, forward, **kwargs)``; its last
            returned element is decoded as the result.
        vae_mean, vae_scaler: Latents are normalized as
            ``(z - vae_mean) * vae_scaler`` and denormalized before decoding.
        cos_sim_mean, cos_sim_std: Used to standardize the mean cosine
            similarity between the two input frames.
        is_train: Selects the training branch or the sampling branch.

    Returns:
        Training: ``(loss_dict, latent, cond_tokens, denoise_args)``.
        Sampling: the decoded frames clamped to [0, 1] with padding cropped off.
    """
    frames, padder, frame_0, frame_1, _ = preprocess_frames(frames)
    cond_frames = torch.cat((frame_0, frame_1), dim=0)
    # Cosine similarity (default dim=1) averaged over the two remaining dims,
    # then standardized; assuming frame_0 is (B, C, H, W) this is (B, 1).
    cos_sim = torch.mean(torch.cosine_similarity(frame_0, frame_1), dim=[1, 2])
    difference = ((cos_sim - cos_sim_mean) / cos_sim_std).unsqueeze(1).to(frames.device)
    denoise_args = {"cond_frames": padder.pad(cond_frames), "difference": difference}

    # No gradients flow through the VAE encode in either branch.
    with torch.no_grad():
        posterior, cond_tokens = vae.module.encode(padder.pad(frames))
        latent = (posterior.sample() - vae_mean).mul_(vae_scaler)

    if is_train:
        loss_dict = transport.training_losses(model, latent, **denoise_args)
        return loss_dict, latent, cond_tokens, denoise_args

    with torch.no_grad():
        noise = torch.randn_like(latent).to(frames.device)
        samples = sample_fn(noise, model.module.forward, **denoise_args)[-1]
        # Undo the latent normalization applied after encoding.
        generated = vae.module.decode(samples / vae_scaler + vae_mean, cond_tokens)
    generated = padder.unpad(generated.clamp(0., 1.))
    return generated