import os

# Must be set before torch is imported (directly or via transformers/diffusers).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import argparse
import os.path as osp

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TVF
from tensorboardX import SummaryWriter
from transformers import CLIPVisionModel, CLIPProcessor
from diffusers import AutoencoderKL

from models.unet_dual_encoder import Embedding_Adapter
from models.diffusion_model import SpaceTimeUnet
from distributed import get_rank, synchronize

parser = argparse.ArgumentParser(description="Configuration of the training script.")
parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training")
parser.add_argument('--dataset', default="fashion_dataset/train", help="Path to the dataset")
parser.add_argument('--dataset_vae', default="fashion_dataset_tensor", help="Path to the tensors of latent space")
parser.add_argument('--output_dir', default="checkpoint", help="Path to save the checkpoints")
args = parser.parse_args()

# init_method="env://" expects the environment variables set by
# torchrun / torch.distributed.launch.
torch.distributed.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
synchronize()

frameLimit = 70

if get_rank() == 0:
    writer = SummaryWriter('video_progress')


def cosine_beta_schedule(timesteps, start=0.0001, end=0.02):
    # Betas ramp from `start` to `end` along a cosine curve.
    betas = []
    for i in reversed(range(timesteps)):
        T = timesteps - 1
        beta = start + 0.5 * (end - start) * (1 + np.cos((i / T) * np.pi))
        betas.append(beta)
    return torch.Tensor(betas)


def get_index_from_list(vals, t, x_shape):
    # Gather the schedule value for each timestep in the batch and reshape it
    # so it broadcasts over [b, f, c, h, w].
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


def forward_diffusion_sample(x_0, t):
    # Sample x_t ~ q(x_t | x_0) in closed form and return it together with the noise used.
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)


T = 1000
betas = cosine_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
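# forward_diffusion_sample() above implements the closed-form DDPM forward process
#     x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * eps,  eps ~ N(0, I),
# so any timestep is reached with a single gather instead of t sequential noising steps.
# Illustrative sanity check (an addition, safe to remove): alphas_cumprod should decay
# from ~1 at t=0 towards ~0 at t=T-1 for the schedule parameters used here.
assert alphas_cumprod[0] > 0.99 and alphas_cumprod[-1] < 0.01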
def get_transform():
    image_transforms = transforms.Compose([
        transforms.Resize((640, 512), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ])
    return image_transforms


class VideoFrameDataset(data.Dataset):
    def __init__(self):
        super(VideoFrameDataset, self).__init__()
        self.path = osp.join(args.dataset)
        self.vae_path = osp.join(args.dataset_vae)
        self.video_names = os.listdir(self.path)
        self.transform = get_transform()

    def __getitem__(self, index):
        video_name = self.video_names[index]
        inputImage = torch.load(osp.join(self.vae_path, video_name[:-4] + "_image.pt"), map_location='cpu')
        restOfVideo = torch.load(osp.join(self.vae_path, video_name[:-4] + ".pt"), map_location='cpu')
        return {'image': inputImage, 'video': restOfVideo}

    def __len__(self):
        return len(self.video_names)


vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="vae",
    revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c"
).to(device)
vae.requires_grad_(False)


@torch.no_grad()
def VAE_encode(image):
    init_latent_dist = vae.encode(image).latent_dist.sample()
    init_latent_dist *= 0.18215
    encoded_image = init_latent_dist.unsqueeze(1)
    return encoded_image
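# Illustrative sketch only (never called in this script): one plausible way to produce
# the tensors that VideoFrameDataset expects under --dataset_vae, i.e. "<name>_image.pt"
# (first frame in pixel space) and "<name>.pt" (VAE latents of the whole clip). The real
# preprocessing script is not part of this file; the cv2 frame extraction, the [0, 1]
# normalisation taken from get_transform(), and the 70-frame limit are assumptions.
from PIL import Image


@torch.no_grad()
def precompute_latents_example(video_path, out_dir):
    transform = get_transform()
    cap = cv2.VideoCapture(video_path)
    frames = []
    while len(frames) < frameLimit:
        ok, frame = cap.read()
        if not ok:
            break
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frames.append(transform(frame))                # [3, 640, 512] in [0, 1]
    cap.release()
    video = torch.stack(frames).to(device)             # [f, 3, 640, 512]
    latents = VAE_encode(video).squeeze(1)             # [f, 4, 80, 64]
    name = osp.splitext(osp.basename(video_path))[0]
    torch.save(video[0].cpu(), osp.join(out_dir, name + "_image.pt"))   # first frame, pixels
    torch.save(latents.cpu(), osp.join(out_dir, name + ".pt"))          # latent video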
Net = SpaceTimeUnet(
    dim=64,
    channels=4,
    dim_mult=(1, 2, 4, 8),
    temporal_compression=(False, False, False, True),
    self_attns=(False, False, False, True),
    condition_on_timestep=True,
).to(device)

adapter = Embedding_Adapter(input_nc=1280, output_nc=1280).to(device)

clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_encoder.requires_grad_(False)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

parameters = list(Net.parameters()) + list(adapter.parameters())
optimizerG = optim.AdamW(parameters, lr=0.0001, weight_decay=0.01)

Net = nn.parallel.DistributedDataParallel(
    Net,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    broadcast_buffers=False)

adapter = nn.parallel.DistributedDataParallel(
    adapter,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    broadcast_buffers=False)


def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset)
    if shuffle:
        return data.RandomSampler(dataset)
    else:
        return data.SequentialSampler(dataset)


train_dataset = VideoFrameDataset()
sampler = data_sampler(train_dataset, shuffle=True, distributed=True)
batch = 2
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch,
    sampler=sampler,
    num_workers=1,
    drop_last=True)


def save_video_frames_as_mp4(frames, fps, save_path):
    frame_h, frame_w = frames[0].shape[2:]
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    video = cv2.VideoWriter(save_path, fourcc, fps, (frame_w, frame_h))
    frames = frames[0]
    for frame in frames:
        frame = np.array(TVF.to_pil_image(frame))
        video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    video.release()


mseloss = torch.nn.MSELoss(reduction="mean")


def get_loss(input_image, latent_video):
    # Noise the latent video at a random timestep, keep the first (conditioning)
    # frame clean, and train the UNet to predict the added noise.
    timesteps = torch.randint(0, T, (batch,), device=device)
    timesteps = timesteps.long()

    initial_frame_latent_video = latent_video[:, 0:1].clone().detach()  # [b, 1, c, h, w]

    x_noisy, noise = forward_diffusion_sample(latent_video, timesteps)
    x_noisy[:, 0:1] = initial_frame_latent_video
    noise[:, 0:1] = torch.zeros_like(initial_frame_latent_video)

    x_noisy = x_noisy.permute(0, 2, 1, 3, 4)  # [b, c, f, h, w] for the UNet

    inputs = clip_processor(images=list(input_image), return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    clip_hidden_states = clip_encoder(**inputs).last_hidden_state.to(device)
    vae_hidden_states = vae.encode(input_image).latent_dist.sample() * 0.18215
    encoder_hidden_states = adapter(clip_hidden_states, vae_hidden_states)

    noise_pred = Net(x_noisy, encoder_hidden_states, timestep=timesteps.float())
    noise_pred = noise_pred.permute(0, 2, 1, 3, 4)  # back to [b, f, c, h, w]

    loss = 0.0
    for i in range(frameLimit):
        loss += mseloss(noise_pred[:, i, :, :, :], noise[:, i, :, :, :])
    return loss


@torch.no_grad()
def VAE_decode(video):
    # Decode each latent frame back to pixel space and stack into [b, f, 3, h, w].
    decoded_video = None
    for i in range(video.shape[1]):
        image = video[:, i, :, :, :]
        image = 1 / 0.18215 * image
        image = vae.decode(image).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        if i == 0:
            decoded_video = image.unsqueeze(1)
        else:
            decoded_video = torch.cat([decoded_video, image.unsqueeze(1)], 1)
    return decoded_video


@torch.no_grad()
def sample_timestep(x, image, t):
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    with torch.cuda.amp.autocast():
        sample_output = Net(x.permute(0, 2, 1, 3, 4), image, timestep=t.float())
    sample_output = sample_output.permute(0, 2, 1, 3, 4)

    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * sample_output / sqrt_one_minus_alphas_cumprod_t
    )

    if t.item() == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
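# sample_timestep() above is standard DDPM ancestral sampling (Ho et al., 2020):
# with eps_theta the predicted noise, the model mean is
#     mu_theta(x_t, t) = (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alphas_cumprod_t) * eps_theta),
# and for t > 0 the next latent is x_{t-1} = mu_theta + sqrt(posterior_variance_t) * z with z ~ N(0, I).
# The only deviation in this script is that the clean first-frame latent is re-inserted
# after every step in the sampling loop below, keeping the conditioning frame fixed.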
@torch.no_grad()
def get_image_embedding(input_image):
    # CLIP and VAE features of the conditioning frame, fused by the adapter.
    inputs = clip_processor(images=list(input_image), return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    clip_hidden_states = clip_encoder(**inputs).last_hidden_state.to(device)
    vae_hidden_states = vae.encode(input_image).latent_dist.sample() * 0.18215
    encoder_hidden_states = adapter(clip_hidden_states, vae_hidden_states)
    return encoder_hidden_states


if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
if not os.path.exists('training_sample'):
    os.makedirs('training_sample')

step = 0
for epoch in range(2500):
    Net.train()
    adapter.train()

    for batch_data in train_dataloader:  # renamed from `data` to avoid shadowing the torch.utils.data alias
        step += 1
        vae_video = batch_data['video'].to(device=device)  # [b, f, c, h, w]
        image = batch_data['image'].to(device=device)

        loss = get_loss(input_image=image, latent_video=vae_video)

        optimizerG.zero_grad()
        loss.backward()
        optimizerG.step()

    if get_rank() == 0 and epoch % 40 == 0:
        writer.add_scalar('loss', loss, step)

    if get_rank() == 0 and epoch % 100 == 0:
        torch.save(
            {
                'net': Net.module.state_dict(),
                'adapter': adapter.module.state_dict(),
                'opt': optimizerG.state_dict()
            },
            args.output_dir + "/model_" + str(epoch) + "_" + str(step) + ".pth")

    if get_rank() == 0 and epoch % 100 == 0:
        # Sample a video from pure noise, conditioned on the first image of the last batch.
        noise_video = torch.randn([1, frameLimit, 4, 80, 64]).to(device)
        encoder_hidden_states = get_image_embedding(input_image=image[0].unsqueeze(0))
        encoded_image = VAE_encode(image[0].unsqueeze(0))
        noise_video[:, 0:1] = encoded_image
        with torch.no_grad():
            for i in range(0, T)[::-1]:
                t = torch.full((1,), i, device=device).long()
                noise_video = sample_timestep(noise_video, encoder_hidden_states, t)
                noise_video[:, 0:1] = encoded_image
        final_video = VAE_decode(noise_video)
        writer.add_image('input image', image[0], step)
        writer.add_video('video', final_video, step)
        save_video_frames_as_mp4(final_video, 25, "training_sample/video" + str(epoch) + ".mp4")

if get_rank() == 0:
    torch.save(
        {
            'net': Net.module.state_dict(),
            'adapter': adapter.module.state_dict()
        },
        args.output_dir + "/vae_clip_e100.pth")
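# Illustrative sketch (an addition, not called anywhere): restoring one of the checkpoints
# written above. The key names ('net', 'adapter', 'opt') match the torch.save calls in this
# script; the default path is just the final checkpoint saved at the end of training.
def load_checkpoint_example(path=args.output_dir + "/vae_clip_e100.pth"):
    state = torch.load(path, map_location=device)
    Net.module.load_state_dict(state['net'])
    adapter.module.load_state_dict(state['adapter'])
    if 'opt' in state:  # the final checkpoint omits the optimizer state
        optimizerG.load_state_dict(state['opt'])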