import torch
import cv2
import torchvision.transforms as transforms
from models.unet_dual_encoder import Embedding_Adapter
from models.diffusion_model import SpaceTimeUnet
import numpy as np
import torchvision.transforms.functional as TVF
from diffusers import AutoencoderKL
from PIL import Image
from transformers import CLIPVisionModel, CLIPProcessor
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
frameLimit = 70  # number of frames in each generated video

def cosine_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Noise schedule easing from `start` to `end` along a cosine curve.

    Note: despite the name, this interpolates a fixed beta range with a cosine
    ramp; it is not the Nichol & Dhariwal cosine alpha-bar schedule.
    """
    T = timesteps - 1
    betas = []
    for i in reversed(range(timesteps)):
        beta = start + 0.5 * (end - start) * (1 + np.cos((i / T) * np.pi))
        betas.append(beta)
    return torch.Tensor(betas)

def get_index_from_list(vals, t, x_shape):
    """Gather the schedule values at timesteps `t`, reshaped to broadcast over x."""
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


def forward_diffusion_sample(x_0, t):
    """Closed-form q(x_t | x_0): return the noised sample and the noise used."""
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
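
# forward_diffusion_sample is the training-time noising step and is never
# called during inference below. A minimal sketch of how it would be used
# (x_0 and t here are hypothetical placeholders):
#
#   x_0 = torch.randn(1, 4, 80, 64)              # a clean latent
#   t = torch.randint(0, T, (1,)).long()         # a random timestep
#   x_t, eps = forward_diffusion_sample(x_0, t)  # noised latent + the added noise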

T = 1000
betas = cosine_beta_schedule(timesteps=T)

# Pre-calculate different terms for the closed-form diffusion equations
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
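
# Quick sanity check on the schedule (a sketch; it asserts only shapes and the
# endpoints implied by the cosine ramp above):
#
#   assert betas.shape == (T,)
#   assert abs(betas[0].item() - 0.0001) < 1e-8  # first beta sits at `start`
#   assert abs(betas[-1].item() - 0.02) < 1e-7   # last beta sits at `end`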

def get_transform():
    image_transforms = transforms.Compose([
        transforms.Resize((640, 512), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ])
    return image_transforms

vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="vae",
    revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c",
)
vae.to(device)
vae.requires_grad_(False)

with torch.no_grad():
    Net = SpaceTimeUnet(
        dim=64,
        channels=4,
        dim_mult=(1, 2, 4, 8),
        temporal_compression=(False, False, False, True),
        self_attns=(False, False, False, True),
        condition_on_timestep=True,
    ).to(device)

adapter = Embedding_Adapter(input_nc=1280, output_nc=1280).to(device)

clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_encoder.requires_grad_(False)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the pretrained FashionFlow weights from the Hub, then free the checkpoint
checkpoint = torch.load(
    hf_hub_download(repo_id="sunjuice/FashionFlow_model", filename="FashionFlow_checkpoint.pth"),
    map_location=torch.device('cpu'),
)
Net.load_state_dict(checkpoint['net'])
adapter.load_state_dict(checkpoint['adapter'])
del checkpoint
torch.cuda.empty_cache()

def save_video_frames_as_mp4(frames, fps, save_path):
    # frames: [batch, frames, channels, H, W]; only the first batch item is written
    frame_h, frame_w = frames[0].shape[2:]
    fourcc = cv2.VideoWriter_fourcc(*'avc1')
    video = cv2.VideoWriter(save_path, fourcc, fps, (frame_w, frame_h))
    for frame in frames[0]:
        frame = np.array(TVF.to_pil_image(frame))
        video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    video.release()

def VAE_encode(image):
    init_latent_dist = vae.encode(image).latent_dist.sample()
    init_latent_dist *= 0.18215  # Stable Diffusion latent scaling factor
    encoded_image = init_latent_dist.unsqueeze(1)  # add a frame dimension
    return encoded_image


def VAE_decode(video, vae_net):
    decoded_video = None
    for i in range(video.shape[1]):  # decode frame by frame to limit memory use
        image = video[:, i, :, :, :]
        image = 1 / 0.18215 * image  # undo the latent scaling
        image = vae_net.decode(image).sample
        image = image.clamp(0, 1)
        if i == 0:
            decoded_video = image.unsqueeze(1)
        else:
            decoded_video = torch.cat([decoded_video, image.unsqueeze(1)], 1)
    return decoded_video

def sample_timestep(x, image, t):
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    # Call model (current latents -> noise prediction); the UNet expects
    # [batch, channels, frames, H, W], hence the permutes
    with torch.cuda.amp.autocast():
        sample_output = Net(x.permute(0, 2, 1, 3, 4), image, timestep=t.float())
        sample_output = sample_output.permute(0, 2, 1, 3, 4)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * sample_output / sqrt_one_minus_alphas_cumprod_t
    )
    if t.item() == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
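
# The update above is the standard DDPM reverse step:
#   mean = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphabar_t) * eps_theta(x_t, t))
#   x_{t-1} = mean + sqrt(posterior_variance_t) * z,  z ~ N(0, I)  (no noise at t = 0)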

def tensor2image(tensor):
    """Convert a [1, C, H, W] tensor in [0, 1] to a PIL image (unused in this script)."""
    numpy_image = tensor[0].cpu().detach().numpy()
    rescaled_image = (numpy_image * 255).astype(np.uint8)
    pil_image = Image.fromarray(rescaled_image.transpose(1, 2, 0))
    return pil_image


def get_image_embedding(input_image):
    """Fuse CLIP and VAE features of the conditioning image via the adapter."""
    inputs = clip_processor(images=list(input_image), return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    clip_hidden_states = clip_encoder(**inputs).last_hidden_state.to(device)
    vae_hidden_states = vae.encode(input_image).latent_dist.sample() * 0.18215
    encoder_hidden_states = adapter(clip_hidden_states, vae_hidden_states)
    return encoder_hidden_states

def predict_fn(input_image, progress=gr.Progress()):
    # Note: gr.Image(type="pil") passes a PIL image, not a file path
    img2tensor = get_transform()
    image = img2tensor(input_image).unsqueeze(0).to(device)
    encoder_hidden_states = get_image_embedding(input_image=image)
    encoded_image = VAE_encode(image)
    # Latent video: frameLimit frames of 4x80x64 latents (640x512 pixels / 8)
    noise_video = torch.randn([1, frameLimit, 4, 80, 64]).to(device)
    noise_video[:, 0:1] = encoded_image
    with torch.no_grad():
        for i in progress.tqdm(range(0, T)[::-1]):
            t = torch.full((1,), i, device=device).long()
            noise_video = sample_timestep(noise_video, encoder_hidden_states, t)
            # Re-anchor the first frame to the conditioning image after every step
            noise_video[:, 0:1] = encoded_image
    final_video = VAE_decode(noise_video, vae)
    save_video_frames_as_mp4(final_video, 25, "result.mp4")
    return "result.mp4"

with gr.Blocks() as demo:
    with gr.Tab("Image-to-Video"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                img_generate = gr.Button("Generate Video")
            with gr.Column():
                img_output = gr.Video(label="Generated Video")
        gr.Examples(
            examples=[
                ['sample/blue.jpg'],
            ],
            inputs=[image_input],
            outputs=[],
        )
        img_generate.click(
            fn=predict_fn,
            inputs=image_input,
            outputs=img_output,
        )

demo.launch()