File size: 7,291 Bytes
f075308
 
 
 
 
 
 
 
 
 
 
 
 
baec227
 
f075308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ff190a
f075308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0c94f0
f075308
 
 
 
 
 
 
95ba07b
f075308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95ba07b
 
f075308
 
 
 
 
 
 
 
 
 
 
 
 
95ba07b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import torch
import cv2
import torchvision.transforms as transforms
from models.unet_dual_encoder import Embedding_Adapter
from models.diffusion_model import SpaceTimeUnet
import numpy as np
import torchvision.transforms.functional as TVF
from diffusers import AutoencoderKL
from PIL import Image
from transformers import CLIPVisionModel, CLIPProcessor
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Size of the frame axis of the latent video sampled in predict_fn.
frameLimit = 70

def cosine_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Build a noise schedule that rises from ``start`` to ``end`` on a half cosine.

    Element ``k`` of the result is
    ``start + 0.5 * (end - start) * (1 + cos(pi * (T - k) / T))`` with
    ``T = timesteps - 1``, so the first value is ``start`` and the last
    is ``end``.

    Args:
        timesteps: Number of beta values to produce.
        start: Value of the first beta.
        end: Value of the last beta.

    Returns:
        A 1-D ``torch.Tensor`` holding ``timesteps`` values (empty when
        ``timesteps`` is not positive).
    """
    if timesteps == 1:
        # With a single step there is no span to interpolate over, and the
        # general formula would divide by zero; every schedule begins at
        # ``start``, so that is the one value returned.
        return torch.Tensor([start])

    last_step = timesteps - 1
    half_span = 0.5 * (end - start)
    betas = [
        start + half_span * (1 + np.cos((i / last_step) * np.pi))
        for i in reversed(range(timesteps))
    ]
    return torch.Tensor(betas)


def get_index_from_list(vals, t, x_shape):
    """Pick ``vals[t]`` per batch element, shaped to broadcast against ``x_shape``.

    The lookup is performed with the indices moved to the CPU; the result is
    reshaped to ``(batch, 1, ..., 1)`` with as many dimensions as ``x_shape``
    and then moved to the device ``t`` lives on.
    """
    picked = vals.gather(-1, t.cpu())
    singleton_dims = [1] * (len(x_shape) - 1)
    broadcastable = picked.reshape(t.shape[0], *singleton_dims)
    return broadcastable.to(t.device)

def forward_diffusion_sample(x_0, t):
    """Noise ``x_0`` to timestep ``t`` in closed form.

    Draws Gaussian noise shaped like ``x_0`` and mixes it with ``x_0`` using
    the precomputed ``sqrt_alphas_cumprod`` and
    ``sqrt_one_minus_alphas_cumprod`` tables looked up at ``t``.

    Returns:
        A tuple ``(noisy_sample, noise)``, both on ``device``.
    """
    noise = torch.randn_like(x_0)
    shape = x_0.shape
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, shape).to(device)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, shape).to(device)
    noise_on_device = noise.to(device)
    # mean + variance
    noisy_sample = signal_scale * x_0.to(device) + noise_scale * noise_on_device
    return noisy_sample, noise_on_device

# Number of diffusion steps and the per-step noise schedule.
T = 1000
betas = cosine_beta_schedule(timesteps=T)

# Terms derived from the schedule once, so the closed-form sampling
# equations can simply look them up by timestep.
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Cumulative product shifted right by one step, with 1.0 in the first slot.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

def get_transform():
    """Return the preprocessing pipeline applied to an input image.

    The pipeline bilinearly resizes the image to 640x512 (the size tuple
    passed to ``Resize``) and then converts it to a tensor.
    """
    resize = transforms.Resize(
        (640, 512),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    to_tensor = transforms.ToTensor()
    return transforms.Compose([resize, to_tensor])

# --- Model setup (runs once at import time) ---------------------------------
# VAE taken from the "vae" subfolder of Stable Diffusion v1.4, pinned to a
# fixed revision. It is moved to `device` and its parameters are frozen.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4",
            subfolder="vae",
            revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c")
vae.to(device)
vae.requires_grad_(False)

# Denoising network; `channels = 4` matches the 4-channel latents sampled in
# predict_fn. It is built under no_grad and moved to `device`.
with torch.no_grad():
    Net = SpaceTimeUnet(
        dim = 64,
        channels = 4,
        dim_mult = (1, 2, 4, 8),
        temporal_compression = (False, False, False, True),
        self_attns = (False, False, False, True),
        condition_on_timestep=True
    ).to(device)
# Adapter that combines CLIP and VAE features into the conditioning passed to
# Net (see get_image_embedding).
adapter = Embedding_Adapter(input_nc=1280, output_nc=1280).to(device)

# Frozen CLIP vision encoder plus its matching image processor.
clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_encoder.requires_grad_(False)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Download the checkpoint from the Hugging Face Hub and load it on the CPU
# first; load_state_dict then copies the weights into the modules above.
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the remote file is untrusted -- consider whether
# `weights_only=True` is usable with this checkpoint.
checkpoint = torch.load(hf_hub_download(repo_id="sunjuice/FashionFlow_model", filename="FashionFlow_checkpoint.pth"), map_location=torch.device('cpu'))
Net.load_state_dict(checkpoint['net'])
adapter.load_state_dict(checkpoint['adapter'])
# Drop the CPU copy of the weights and release cached GPU memory.
del checkpoint
torch.cuda.empty_cache()

def save_video_frames_as_mp4(frames, fps, save_path):
    """Encode the first clip of ``frames`` as a video file at ``save_path``.

    Args:
        frames: Batch of clips; only ``frames[0]`` is written. The frame
            height and width are read from the last two dimensions of that
            clip, and each frame must be convertible by
            ``torchvision.transforms.functional.to_pil_image``.
        fps: Frame rate passed to the video writer.
        save_path: Destination file path.

    The writer is always released, even if converting or writing a frame
    raises, so the output file handle is not leaked.
    """
    clip = frames[0]
    frame_h, frame_w = clip.shape[2:]
    fourcc = cv2.VideoWriter_fourcc(*'avc1')
    # NOTE(review): the writer is not checked with isOpened(); if the 'avc1'
    # codec is unavailable the writes below presumably do nothing -- confirm.
    video = cv2.VideoWriter(save_path, fourcc, fps, (frame_w, frame_h))
    try:
        for frame in clip:
            rgb = np.array(TVF.to_pil_image(frame))
            # OpenCV expects BGR channel order.
            video.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    finally:
        video.release()

@torch.no_grad()
def VAE_encode(image):
    """Encode ``image`` with the global VAE into a scaled one-frame latent.

    A latent is sampled from the VAE's latent distribution, multiplied in
    place by 0.18215, and given an extra axis at position 1.
    """
    latents = vae.encode(image).latent_dist.sample()
    latents.mul_(0.18215)
    return latents.unsqueeze(1)

@torch.no_grad()
def VAE_decode(video, vae_net):
    """Decode a latent video frame by frame with ``vae_net``.

    Args:
        video: 5-D tensor indexed as ``video[:, i, :, :, :]`` for frame ``i``.
        vae_net: Object whose ``decode(latent)`` returns something with a
            ``sample`` tensor attribute.

    Returns:
        The decoded frames, each clamped to ``[0, 1]``, stacked along
        dimension 1; ``None`` when ``video`` has no frames.
    """
    num_frames = video.shape[1]
    if num_frames == 0:
        return None

    # Collect the frames and stack once: concatenating inside the loop
    # re-copies every previously decoded frame on each iteration.
    decoded_frames = []
    for i in range(num_frames):
        latent = 1 / 0.18215 * video[:, i, :, :, :]  # undo the encode-time scaling
        frame = vae_net.decode(latent).sample
        decoded_frames.append(frame.clamp(0, 1))
    return torch.stack(decoded_frames, dim=1)

@torch.no_grad()
def sample_timestep(x, image, t):
    """Run one reverse-diffusion step on ``x`` at timestep ``t``.

    ``Net`` is called on ``x`` with dimensions 1 and 2 swapped, conditioned on
    ``image`` and the float timestep; its output is swapped back and used as
    the noise prediction for the mean. ``t`` must hold a single element
    because ``t.item()`` is used to detect the final step.

    Returns:
        The model mean at ``t == 0``; otherwise the mean plus Gaussian noise
        scaled by the square root of the posterior variance at ``t``.
    """
    shape = x.shape
    beta_t = get_index_from_list(betas, t, shape)
    sqrt_one_minus_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, shape
    )
    sqrt_recip_alpha_t = get_index_from_list(sqrt_recip_alphas, t, shape)

    # Call model (current image - noise prediction)
    with torch.cuda.amp.autocast():
        predicted = Net(x.permute(0, 2, 1, 3, 4), image, timestep=t.float())
    predicted = predicted.permute(0, 2, 1, 3, 4)

    model_mean = sqrt_recip_alpha_t * (
        x - beta_t * predicted / sqrt_one_minus_cumprod_t
    )
    if t.item() == 0:
        # Final step: return the mean without adding noise.
        return model_mean

    noise = torch.randn_like(x)
    variance_t = get_index_from_list(posterior_variance, t, shape)
    return model_mean + torch.sqrt(variance_t) * noise

def tensor2image(tensor):
    """Convert the first element of a batched image tensor to a PIL image.

    ``tensor[0]`` is moved to the CPU, scaled by 255, and transposed from
    channel-first to channel-last before being handed to ``Image.fromarray``.
    Scaled values are clipped to ``0..255`` before the ``uint8`` cast, so
    inputs slightly outside ``[0, 1]`` saturate instead of wrapping around.

    Args:
        tensor: Batched tensor; ``tensor[0]`` must have three dimensions
            (it is transposed with axes ``(1, 2, 0)``).

    Returns:
        A ``PIL.Image.Image`` built from the first element.
    """
    chw = tensor[0].cpu().detach().numpy()
    scaled = np.clip(chw * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(scaled.transpose(1, 2, 0))

@torch.no_grad()
def get_image_embedding(input_image):
    """Build the conditioning embedding for ``input_image``.

    The batch is split into a list of images for the CLIP processor, encoded
    by the CLIP vision model (last hidden state), and separately encoded by
    the VAE (sampled latent scaled by 0.18215). Both feature sets are passed
    to ``adapter``, whose output is returned.
    """
    clip_inputs = clip_processor(images=list(input_image), return_tensors="pt")
    clip_inputs = {name: value.to(device) for name, value in clip_inputs.items()}
    clip_features = clip_encoder(**clip_inputs).last_hidden_state.to(device)
    vae_features = vae.encode(input_image).latent_dist.sample() * 0.18215
    return adapter(clip_features, vae_features)

def predict_fn(img_path, progress=gr.Progress()):
    """Generate a video from one input image and return the output file path.

    The image is resized by ``get_transform``, embedded for conditioning, and
    VAE-encoded. A random latent video of shape ``[1, frameLimit, 4, 80, 64]``
    is then denoised from step ``T - 1`` down to ``0``; the first frame slot
    is overwritten with the encoded input image before sampling and after
    every step. The result is decoded and written to ``result.mp4`` at 25 fps.

    Args:
        img_path: A PIL image (what the Gradio image component delivers) or
            a filesystem path to an image.
        progress: Gradio progress tracker used to report sampling progress.

    Returns:
        The string ``"result.mp4"``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img_path is None:
        # Clicking the button with an empty image component passes None.
        raise gr.Error("Please provide an input image before generating a video.")
    if isinstance(img_path, str):
        with Image.open(img_path) as opened:
            pil_image = opened.convert("RGB")
    else:
        # Force three channels: RGBA or grayscale uploads would otherwise give
        # the VAE a tensor with the wrong channel count.
        pil_image = img_path.convert("RGB")

    img2tensor = get_transform()
    image = img2tensor(pil_image).unsqueeze(0).to(device)
    encoder_hidden_states = get_image_embedding(input_image=image)
    encoded_image = VAE_encode(image)
    noise_video = torch.randn([1, frameLimit, 4, 80, 64]).to(device)
    noise_video[:, 0:1] = encoded_image
    with torch.no_grad():
        for i in progress.tqdm(range(0, T)[::-1]):
            t = torch.full((1,), i, device=device).long()
            noise_video = sample_timestep(noise_video, encoder_hidden_states, t)
            # Keep the first frame pinned to the input image's latent.
            noise_video[:, 0:1] = encoded_image
        final_video = VAE_decode(noise_video, vae)
        save_video_frames_as_mp4(final_video, 25, "result.mp4")
    return "result.mp4"

# --- Gradio user interface ---------------------------------------------------
# One "Image-to-Video" tab: the left column holds the image input and the
# generate button, the right column shows the resulting video.
with gr.Blocks() as demo:
    with gr.Tab("Image-to-Video"):
        with gr.Row():
            with gr.Column():
                # type="pil" means predict_fn receives a PIL image object.
                image_input = gr.Image(type="pil", label="Input Image")
                img_generate = gr.Button("Generate Video")
            with gr.Column():
                img_output = gr.Video(label="Generated Video")
        # Example input offered for image_input; no outputs are attached to it.
        gr.Examples(
            examples=[
                ['sample/blue.jpg',]
            ],
            inputs=[image_input],
            outputs=[]
        )

        # Clicking the button runs predict_fn on the current image and shows
        # the returned video file in img_output.
        img_generate.click(
            fn=predict_fn,
            inputs=image_input,
            outputs=img_output
        )

# Start the app when this module is executed.
demo.launch()