Pedro-Quesado's picture
Upload 154 files
81f9834 verified
from utils.interpolation_utils.embedding import TimestepEmbedder, get_pos_embedding
from utils.interpolation_utils.klperceptual import KLLPIPSWithDiscriminator
from utils.interpolation_utils.distributions import DiagonalGaussianDistribution
from utils.interpolation_utils.cal_metrics import CalMetrics
import torch
class InputPadder:
    """Pad the two trailing dimensions of a tensor up to a multiple of ``divisor``.

    The amount needed on each axis is split across both sides; when it is odd,
    the extra element goes to the right / bottom side.
    """

    def __init__(self, img_size, divisor=32):
        """Pre-compute the padding for an image of the given size.

        Args:
            img_size: Two-element sequence ``(height, width)``.
            divisor: Value the padded height and width must be a multiple of.
        """
        self.ht, self.wd = img_size
        # Distance from each size to the next multiple of ``divisor``
        # (zero when the size is already a multiple).
        extra_ht = (-self.ht) % divisor
        extra_wd = (-self.wd) % divisor
        left, top = extra_wd // 2, extra_ht // 2
        # torch.nn.functional.pad expects the last dimension first:
        # (left, right, top, bottom).
        self._pad = [left, extra_wd - left, top, extra_ht - top]

    def pad(self, x):
        """Return ``x`` padded by the pre-computed amounts, replicating edge values."""
        padding = self._pad
        return torch.nn.functional.pad(x, padding, mode="replicate")

    def unpad(self, x):
        """Crop the margins added by :meth:`pad` from the last two dimensions of ``x``."""
        left, right, top, bottom = self._pad
        ht, wd = x.shape[-2:]
        return x[..., top:ht - bottom, left:wd - right]
def preprocess_cond(x, eps=1e-8):
    """Standardise each sample of ``x`` and return pair-averaged statistics.

    Every entry along dim 0 is normalised with its own mean and standard
    deviation, computed over all of its remaining dimensions. The per-sample
    statistics are then split into two halves along dim 0 and averaged
    element-wise.

    Args:
        x: Tensor with at least two dimensions. Presumably the first and
            second halves along dim 0 are paired inputs -- confirm with caller.
        eps: Constant added to the standard deviation before dividing.

    Returns:
        ``(x_norm, (mean, std))`` where ``x_norm`` has the shape of ``x`` and
        ``mean`` / ``std`` are the averages of the two halves' statistics,
        each with the same number of dimensions as ``x``.
    """
    per_sample = x.flatten(1)
    # One leading entry per sample, then singleton dims so the statistics
    # broadcast against ``x``.
    stat_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
    x_mean = per_sample.mean(dim=-1).reshape(stat_shape)
    x_std = (per_sample.std(dim=-1) + eps).reshape(stat_shape)

    def _average_halves(t):
        # Split along dim 0 into two chunks and average them element-wise.
        first, second = t.chunk(2, dim=0)
        return (first + second) / 2

    x_norm = (x - x_mean) / x_std
    return x_norm, (_average_halves(x_mean), _average_halves(x_std))
def preprocess_frames(frames):
    """Rescale a batch of frame triplets and prepare a matching padder.

    Args:
        frames: Tensor whose dim 1 indexes the frames of a triplet (index 0
            and 1 are the two input frames, index 2 the ground truth).
            Presumably shaped ``(B, 3, C, H, W)`` with values in 0..255 --
            confirm with caller.

    Returns:
        ``(frames, padder, frame_0, frame_1, gt)`` where ``frames`` is the
        three frames concatenated along dim 0 (in the order frame_0, frame_1,
        gt), ``padder`` is an ``InputPadder`` built from dims 2 and 3 of that
        tensor, and the remaining items are the individual rescaled frames.
    """
    scaled = frames / 255.
    frame_0, frame_1, gt = (scaled[:, idx] for idx in range(3))
    stacked = torch.cat((frame_0, frame_1, gt), dim=0)
    height, width = stacked.shape[2], stacked.shape[3]
    return stacked, InputPadder([height, width]), frame_0, frame_1, gt
def one_iter_for_vae(model, frames, is_train=True):
    """Run a single forward pass of the VAE on a batch of frame triplets.

    Args:
        model: Callable taking the padded frames and returning
            ``(recon, posterior)``.
        frames: Raw frame batch, forwarded to ``preprocess_frames``.
        is_train: When falsy, the forward pass runs with gradients disabled.

    Returns:
        ``(recon, gt, posterior)``: the reconstruction clamped to ``[0, 1]``
        and cropped back to the unpadded size, the ground-truth frame, and the
        second output of ``model``.
    """
    frames, padder, _, _, gt = preprocess_frames(frames)
    # Evaluation turns gradients off; training leaves the ambient grad mode
    # exactly as the caller set it.
    track_grad = torch.is_grad_enabled() and bool(is_train)
    with torch.set_grad_enabled(track_grad):
        recon, posterior = model(padder.pad(frames))
    cropped = padder.unpad(recon.clamp(0., 1.))
    return cropped, gt, posterior
def one_iter_for_dit(model, vae, frames, transport, sample_fn, vae_mean, vae_scaler, cos_sim_mean, cos_sim_std, is_train=True):
    """Run one training or sampling step of the denoising model.

    The frames are encoded with the VAE (gradients disabled) and the sampled
    latent is shifted by ``vae_mean`` and scaled by ``vae_scaler``. In
    training mode the transport loss is computed on that latent; otherwise a
    latent is sampled from noise, decoded by the VAE and cropped back to the
    unpadded size.

    Args:
        model: Denoising model. May be wrapped (exposing the real network as
            ``.module``, as DistributedDataParallel does) or a plain module.
        vae: Autoencoder providing ``encode`` and ``decode``; may be wrapped
            in the same way as ``model``.
        frames: Raw frame batch, forwarded to ``preprocess_frames``.
        transport: Object providing ``training_losses(model, latent, **kwargs)``.
        sample_fn: Sampler called as ``sample_fn(noise, forward, **kwargs)``;
            the last element of its result is used as the sampled latent.
        vae_mean: Offset subtracted from the sampled latent.
        vae_scaler: Factor the shifted latent is multiplied by.
        cos_sim_mean: Mean used to normalise the frame similarity.
        cos_sim_std: Standard deviation used to normalise the frame similarity.
        is_train: Selects the training (truthy) or sampling (falsy) path.

    Returns:
        Training: ``(loss_dict, latent, cond_tokens, denoise_args)``.
        Sampling: the decoded frame, clamped to ``[0, 1]`` and unpadded.
    """
    frames, padder, frame_0, frame_1, _ = preprocess_frames(frames)

    # Unwrap parallel wrappers when present so that unwrapped models (e.g.
    # single-device inference) work as well.
    vae_core = getattr(vae, "module", vae)

    cond_frames = torch.cat((frame_0, frame_1), dim=0)
    # Cosine similarity along dim 1 (the default), averaged over the two
    # remaining non-batch dims, then normalised with the supplied statistics.
    similarity = torch.mean(torch.cosine_similarity(frame_0, frame_1), dim=[1, 2])
    difference = ((similarity - cos_sim_mean) / cos_sim_std).unsqueeze(1).to(frames.device)
    denoise_args = {"cond_frames": padder.pad(cond_frames), "difference": difference}

    # The VAE only provides targets/conditioning here; no gradients are
    # tracked through it.
    with torch.no_grad():
        posterior, cond_tokens = vae_core.encode(padder.pad(frames))
        latent = (posterior.sample() - vae_mean).mul_(vae_scaler)

    if is_train:
        # The (possibly wrapped) model itself is passed on for training.
        loss_dict = transport.training_losses(model, latent, **denoise_args)
        return loss_dict, latent, cond_tokens, denoise_args

    with torch.no_grad():
        denoiser = getattr(model, "module", model)
        noise = torch.randn_like(latent).to(frames.device)
        samples = sample_fn(noise, denoiser.forward, **denoise_args)[-1]
        # Undo the latent normalisation before decoding.
        generated = vae_core.decode(samples / vae_scaler + vae_mean, cond_tokens)
    generated = padder.unpad(generated.clamp(0., 1.))
    return generated