import torch
import cv2
import torchvision.transforms as transforms
from models.unet_dual_encoder import Embedding_Adapter
from models.diffusion_model import SpaceTimeUnet
import numpy as np
import torchvision.transforms.functional as TVF
from diffusers import AutoencoderKL
from PIL import Image
from transformers import CLIPVisionModel, CLIPProcessor
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
frameLimit = 70  # number of frames generated per video


def cosine_beta_schedule(timesteps, start=0.0001, end=0.02):
    # Cosine-shaped interpolation of the betas between `start` and `end`.
    betas = []
    for i in reversed(range(timesteps)):
        T = timesteps - 1
        beta = start + 0.5 * (end - start) * (1 + np.cos((i / T) * np.pi))
        betas.append(beta)
    return torch.Tensor(betas)


def get_index_from_list(vals, t, x_shape):
    # Gather the schedule value at timestep `t` and reshape it for broadcasting over `x_shape`.
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


def forward_diffusion_sample(x_0, t):
    # Sample x_t from q(x_t | x_0) in closed form; returns the noised sample and the noise.
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)


T = 1000
betas = cosine_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)


def get_transform():
    image_transforms = transforms.Compose([
        transforms.Resize((640, 512), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ])
    return image_transforms


# Frozen Stable Diffusion VAE for encoding/decoding latents.
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="vae",
    revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c",
)
vae.to(device)
vae.requires_grad_(False)

with torch.no_grad():
    # Spatio-temporal U-Net that denoises the video latents.
    Net = SpaceTimeUnet(
        dim=64,
        channels=4,
        dim_mult=(1, 2, 4, 8),
        temporal_compression=(False, False, False, True),
        self_attns=(False, False, False, True),
        condition_on_timestep=True,
    ).to(device)

    # Adapter that fuses CLIP and VAE features into the conditioning embedding.
    adapter = Embedding_Adapter(input_nc=1280, output_nc=1280).to(device)

# Frozen CLIP image encoder used for conditioning.
clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_encoder.requires_grad_(False)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the pretrained FashionFlow weights from the Hugging Face Hub.
checkpoint = torch.load(
    hf_hub_download(repo_id="sunjuice/FashionFlow_model", filename="FashionFlow_checkpoint.pth"),
    map_location=torch.device('cpu'),
)
Net.load_state_dict(checkpoint['net'])
adapter.load_state_dict(checkpoint['adapter'])
del checkpoint
torch.cuda.empty_cache()


def save_video_frames_as_mp4(frames, fps, save_path):
    # Write a [1, F, C, H, W] tensor of frames in [0, 1] to an MP4 file.
    frame_h, frame_w = frames[0].shape[2:]
    fourcc = cv2.VideoWriter_fourcc(*'avc1')
    video = cv2.VideoWriter(save_path, fourcc, fps, (frame_w, frame_h))
    frames = frames[0]
    for frame in frames:
        frame = np.array(TVF.to_pil_image(frame))
        video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    video.release()


@torch.no_grad()
def VAE_encode(image):
    # Encode an image into a scaled latent and add a frame dimension.
    init_latent_dist = vae.encode(image).latent_dist.sample()
    init_latent_dist *= 0.18215
    encoded_image = init_latent_dist.unsqueeze(1)
    return encoded_image


@torch.no_grad()
def VAE_decode(video, vae_net):
    # Decode each latent frame back to pixel space and stack the results into a video tensor.
    decoded_video = None
    for i in range(video.shape[1]):
        image = video[:, i, :, :, :]
        image = 1 / 0.18215 * image
        image = vae_net.decode(image).sample
        image = image.clamp(0, 1)
        if i == 0:
            decoded_video = image.unsqueeze(1)
        else:
            decoded_video = torch.cat([decoded_video, image.unsqueeze(1)], 1)
    return decoded_video


@torch.no_grad()
def sample_timestep(x, image, t):
    # One reverse-diffusion step: predict the noise and compute the posterior mean for timestep t.
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    with torch.cuda.amp.autocast():
        sample_output = Net(x.permute(0, 2, 1, 3, 4), image, timestep=t.float())
        sample_output = sample_output.permute(0, 2, 1, 3, 4)

    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * sample_output / sqrt_one_minus_alphas_cumprod_t
    )

    if t.item() == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


def tensor2image(tensor):
    # Convert a [1, C, H, W] tensor in [0, 1] to a PIL image.
    numpy_image = tensor[0].cpu().detach().numpy()
    rescaled_image = (numpy_image * 255).astype(np.uint8)
    pil_image = Image.fromarray(rescaled_image.transpose(1, 2, 0))
    return pil_image


@torch.no_grad()
def get_image_embedding(input_image):
    # Build the conditioning embedding by fusing CLIP and VAE features of the input image.
    inputs = clip_processor(images=list(input_image), return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    clip_hidden_states = clip_encoder(**inputs).last_hidden_state.to(device)
    vae_hidden_states = vae.encode(input_image).latent_dist.sample() * 0.18215
    encoder_hidden_states = adapter(clip_hidden_states, vae_hidden_states)
    return encoder_hidden_states


def predict_fn(img_path, progress=gr.Progress()):
    # `img_path` is a PIL image (gr.Image(type="pil")); generate a video conditioned on it.
    img2tensor = get_transform()
    image = img2tensor(img_path).unsqueeze(0).to(device)

    encoder_hidden_states = get_image_embedding(input_image=image)
    encoded_image = VAE_encode(image)

    # Start from Gaussian noise in latent space; keep the first frame fixed to the input image.
    noise_video = torch.randn([1, frameLimit, 4, 80, 64]).to(device)
    noise_video[:, 0:1] = encoded_image

    with torch.no_grad():
        for i in progress.tqdm(range(0, T)[::-1]):
            t = torch.full((1,), i, device=device).long()
            noise_video = sample_timestep(noise_video, encoder_hidden_states, t)
            noise_video[:, 0:1] = encoded_image

    final_video = VAE_decode(noise_video, vae)
    save_video_frames_as_mp4(final_video, 25, "result.mp4")
    return "result.mp4"


with gr.Blocks() as demo:
    with gr.Tab("Image-to-Video"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                img_generate = gr.Button("Generate Video")
            with gr.Column():
                img_output = gr.Video(label="Generated Video")

        gr.Examples(
            examples=[
                ['sample/blue.jpg'],
            ],
            inputs=[image_input],
            outputs=[],
        )

        img_generate.click(
            fn=predict_fn,
            inputs=image_input,
            outputs=img_output,
        )

demo.launch()
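
# --- Scripted use without the UI: a minimal sketch, not part of the original app ---
# Kept commented out because demo.launch() above blocks. It assumes "sample/blue.jpg"
# exists locally and that skipping Gradio progress tracking is acceptable; the stand-in
# object below is hypothetical and only provides the .tqdm method that predict_fn uses.
#
#   import types
#   from PIL import Image
#
#   # Stand-in for gr.Progress() so predict_fn can run outside a Gradio event:
#   # its tqdm method simply returns the iterable unchanged.
#   no_progress = types.SimpleNamespace(tqdm=lambda iterable, *a, **kw: iterable)
#
#   video_path = predict_fn(Image.open("sample/blue.jpg").convert("RGB"), progress=no_progress)
#   print(video_path)  # -> "result.mp4"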