PAQUITA1986 committed
Commit fd317e7 · verified · 1 Parent(s): af63c17

Upload app.py

Files changed (1)
  1. app.py +392 -0
app.py ADDED
@@ -0,0 +1,392 @@
+ # Adapted from https://github.com/luosiallen/latent-consistency-model
+ from __future__ import annotations
+
+ import argparse
+ from functools import partial
+ import os
+ import random
+ import time
+ from omegaconf import OmegaConf
+
+ import gradio as gr
+ import numpy as np
+
+ # intel_extension_for_pytorch is optional; it enables acceleration on Intel XPUs.
+ try:
+     import intel_extension_for_pytorch as ipex
+ except ImportError:
+     pass
+
+ from utils.lora import collapse_lora, monkeypatch_remove_lora
+ from utils.lora_handler import LoraHandler
+ from utils.common_utils import load_model_checkpoint
+ from utils.utils import instantiate_from_config
+ from scheduler.t2v_turbo_scheduler import T2VTurboScheduler
+ from pipeline.t2v_turbo_vc2_pipeline import T2VTurboVC2Pipeline
+
+ import torch
+ import torchvision
+
+ from concurrent.futures import ThreadPoolExecutor
+ import uuid
+
+ DESCRIPTION = """# T2V-Turbo 🚀
+
+ Our model is distilled from [VideoCrafter2](https://ailab-cvc.github.io/videocrafter2/).
+
+ T2V-Turbo learns a LoRA on top of the base model by aligning to the reward feedback from [HPSv2.1](https://github.com/tgxs002/HPSv2/tree/master) and the [InternVid2 Stage 2 Model](https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4).
+
+ T2V-Turbo-v2 improves the training techniques by finetuning the full base model and further aligns to [CLIPScore](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K).
+
+ T2V-Turbo is trained purely on WebVid-10M data, whereas T2V-Turbo-v2 carefully optimizes different learning objectives with a mixture of VidGen-1M and WebVid-10M data.
+
+ Moreover, T2V-Turbo-v2 supports distilling motion priors from the training videos.
+
+ [Project page for T2V-Turbo](https://t2v-turbo.github.io) 🥳
+
+ [Project page for T2V-Turbo-v2](https://t2v-turbo-v2.github.io) 🤓
+ """
+ if torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CUDA 😀</p>"
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
+     DESCRIPTION += "\n<p>Running on XPU 🤓</p>"
+ else:
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+ MAX_SEED = np.iinfo(np.int32).max
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
+
+
+ """
+ Operating System Options:
+ If you are using macOS, set device="mps";
+ if you are using Linux or Windows with an Nvidia GPU, set device="cuda";
+ if you are using Linux or Windows with an Intel Arc GPU, set device="xpu".
+ """
+ # device = "mps"  # macOS
+ # device = "xpu"  # Intel Arc GPU
+ device = "cuda"  # Linux & Windows
+
+
+ """
+ DTYPE Options:
+ To reduce GPU memory, you can set DTYPE=torch.float16,
+ but video quality might be compromised.
+ """
+ DTYPE = torch.bfloat16
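+ # NOTE: the dtype actually used for generation is chosen per request via the
+ # "torch.dtype" radio in the UI (see generate() below); its default is also bf16.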
+
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+
+ def save_video(
+     vid_tensor, profile: gr.OAuthProfile | None, metadata: dict, root_path="./", fps=16
+ ):
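+     # vid_tensor is laid out as (c, t, h, w) with values in [-1, 1]; convert it
+     # to (t, h, w, c) uint8 in [0, 255], which torchvision.io.write_video expects.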
+     unique_name = str(uuid.uuid4()) + ".mp4"
+     unique_name = os.path.join(root_path, unique_name)
+
+     video = vid_tensor.detach().cpu()
+     video = torch.clamp(video.float(), -1.0, 1.0)
+     video = video.permute(1, 0, 2, 3)  # (c, t, h, w) -> (t, c, h, w)
+     video = (video + 1.0) / 2.0
+     video = (video * 255).to(torch.uint8).permute(0, 2, 3, 1)
+
+     torchvision.io.write_video(
+         unique_name, video, fps=fps, video_codec="h264", options={"crf": "10"}
+     )
+     return unique_name
+
+
+ def save_videos(
+     video_array, profile: gr.OAuthProfile | None, metadata: dict, fps: int = 16
+ ):
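+     # Encode each generated video on its own thread; the demo produces a single
+     # video per call, so only the first path is returned to the UI.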
+     paths = []
+     root_path = "./videos/"
+     os.makedirs(root_path, exist_ok=True)
+     with ThreadPoolExecutor() as executor:
+         paths = list(
+             executor.map(
+                 save_video,
+                 video_array,
+                 [profile] * len(video_array),
+                 [metadata] * len(video_array),
+                 [root_path] * len(video_array),
+                 [fps] * len(video_array),
+             )
+         )
+     return paths[0]
+
+
+ def generate(
+     prompt: str,
+     seed: int = 0,
+     guidance_scale: float = 7.5,
+     percentage: float = 0.3,
+     num_inference_steps: int = 4,
+     num_frames: int = 16,
+     fps: int = 16,
+     randomize_seed: bool = False,
+     param_dtype="bf16",
+     motion_gs: float = 0.05,
+     use_motion_cond: bool = False,
+     progress=gr.Progress(track_tqdm=True),
+     profile: gr.OAuthProfile | None = None,
+ ):
+     seed = randomize_seed_fn(seed, randomize_seed)
+     torch.manual_seed(seed)
+
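+     # Map the UI's dtype choice onto torch dtypes, then cast the UNet, text
+     # encoder, and VAE before running inference.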
+     if param_dtype == "bf16":
+         dtype = torch.bfloat16
+         unet.dtype = torch.bfloat16
+     elif param_dtype == "fp16":
+         dtype = torch.float16
+         unet.dtype = torch.float16
+     elif param_dtype == "fp32":
+         dtype = torch.float32
+         unet.dtype = torch.float32
+     else:
+         raise ValueError(f"Unknown dtype: {param_dtype}")
+
+     pipeline.unet.to(device, dtype)
+     pipeline.text_encoder.to(device, dtype)
+     pipeline.vae.to(device, dtype)
+     pipeline.to(device, dtype)
+
+     start_time = time.time()
+
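+     # Run the distilled pipeline: num_inference_steps consistency-model sampling
+     # steps (4-8 is typical here); lcm_origin_steps presumably matches the
+     # 200-step teacher schedule used during distillation.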
+     result = pipeline(
+         prompt=prompt,
+         frames=num_frames,
+         fps=fps,
+         guidance_scale=guidance_scale,
+         motion_gs=motion_gs,
+         use_motion_cond=use_motion_cond,
+         percentage=percentage,
+         num_inference_steps=num_inference_steps,
+         lcm_origin_steps=200,
+         num_videos_per_prompt=1,
+     )
+     paths = save_videos(
+         result,
+         profile,
+         metadata={
+             "prompt": prompt,
+             "seed": seed,
+             "guidance_scale": guidance_scale,
+             "num_inference_steps": num_inference_steps,
+         },
+         fps=fps,
+     )
+     print(f"Generation took {time.time() - start_time:.2f}s")
+     return paths, seed
+
+
+ examples = [
+     "An astronaut riding a horse.",
+     "Darth vader surfing in waves.",
+     "Robot dancing in times square.",
+     "Clown fish swimming through the coral reef.",
+     "Pikachu snowboarding.",
+     "With the style of van gogh, A young couple dances under the moonlight by the lake.",
+     "A young woman with glasses is jogging in the park wearing a pink headband.",
+     "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
+     "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+     "With the style of low-poly game art, A majestic, white horse gallops gracefully across a moonlit beach.",
+ ]
+
+
+ if __name__ == "__main__":
+     # Command-line arguments selecting the checkpoints and model version.
+     parser = argparse.ArgumentParser(description="Gradio demo for T2V-Turbo.")
+     parser.add_argument(
+         "--unet_dir",
+         type=str,
+         default="output/vlcm_vc2_mixed_vid_gen_128k_bs3_percen_0p2_mgs_max_0p1/checkpoint-10000/unet.pt",
+         help="Path to the UNet checkpoint.",
+     )
+     parser.add_argument(
+         "--base_model_dir",
+         type=str,
+         default="model_cache/VideoCrafter2_model.ckpt",
+         help="Path to the VideoCrafter2 checkpoint.",
+     )
+     parser.add_argument(
+         "--version",
+         required=True,
+         choices=["v1", "v2"],
+         help="Model version; v2 adds motion-condition support.",
+     )
+     parser.add_argument(
+         "--motion_gs",
+         default=0.05,
+         type=float,
+         help="Guidance scale for the motion condition.",
+     )
+
+     args = parser.parse_args()
+
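+     # Example launch (the unet path below is illustrative, not a real checkpoint):
+     #   python app.py --version v2 --unet_dir checkpoints/unet_mg.pt \
+     #       --base_model_dir model_cache/VideoCrafter2_model.ckpt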
+     config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
+     model_config = config.pop("model", OmegaConf.create())
+     pretrained_t2v = instantiate_from_config(model_config)
+     pretrained_t2v = load_model_checkpoint(pretrained_t2v, args.base_model_dir)
+
+     unet_config = model_config["params"]["unet_config"]
+     unet_config["params"]["use_checkpoint"] = False
+     unet_config["params"]["time_cond_proj_dim"] = 256
+
+     if args.version == "v2":
+         unet_config["params"]["motion_cond_proj_dim"] = 256
+     unet = instantiate_from_config(unet_config)
+
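+     # Checkpoints with "lora" in their path hold LoRA weights (T2V-Turbo v1):
+     # copy the base UNet weights in, apply the LoRA, then merge (collapse) it
+     # and strip the LoRA wrappers so inference runs on a plain UNet. Other
+     # checkpoints are full UNet state dicts and are loaded directly.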
+     if "lora" in args.unet_dir:
+         unet.load_state_dict(
+             pretrained_t2v.model.diffusion_model.state_dict(), strict=False
+         )
+
+         use_unet_lora = True
+         lora_manager = LoraHandler(
+             version="cloneofsimo",
+             use_unet_lora=use_unet_lora,
+             save_for_webui=True,
+             unet_replace_modules=["UNetModel"],
+         )
+         lora_manager.add_lora_to_model(
+             use_unet_lora,
+             unet,
+             lora_manager.unet_replace_modules,
+             lora_path=args.unet_dir,
+             dropout=0.1,
+             r=64,
+         )
+         collapse_lora(unet, lora_manager.unet_replace_modules)
+         monkeypatch_remove_lora(unet)
+     else:
+         unet.load_state_dict(torch.load(args.unet_dir, map_location=device))
+
+     unet.eval()
+     pretrained_t2v.model.diffusion_model = unet
+     scheduler = T2VTurboScheduler(
+         linear_start=model_config["params"]["linear_start"],
+         linear_end=model_config["params"]["linear_end"],
+     )
+     pipeline = T2VTurboVC2Pipeline(pretrained_t2v, scheduler, model_config)
+
+     pipeline.to(device)
+
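+     # Build the Gradio UI: a prompt box, the output video, and advanced options.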
+     with gr.Blocks(css="style.css") as demo:
+         gr.Markdown(DESCRIPTION)
+         gr.DuplicateButton(
+             value="Duplicate Space for private use",
+             elem_id="duplicate-button",
+             visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+         )
+         with gr.Group():
+             with gr.Row():
+                 prompt = gr.Text(
+                     label="Prompt",
+                     show_label=False,
+                     max_lines=1,
+                     placeholder="Enter your prompt",
+                     container=False,
+                 )
+                 run_button = gr.Button("Run", scale=0)
+             result_video = gr.Video(
+                 label="Generated Video", interactive=False, autoplay=True
+             )
+         with gr.Accordion("Advanced options", open=False):
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+                 randomize=True,
+             )
+             randomize_seed = gr.Checkbox(label="Randomize seed across runs", value=True)
+             dtype_choices = ["bf16", "fp16", "fp32"]
+             param_dtype = gr.Radio(
+                 dtype_choices,
+                 label="torch.dtype",
+                 value=dtype_choices[0],
+                 interactive=True,
+                 info="To save GPU memory, use fp16 or bf16. For better quality, use fp32.",
+             )
+             with gr.Row():
+                 percentage = gr.Slider(
+                     label="Percentage of steps to apply motion guidance (v2 w/ MG only)",
+                     minimum=0.0,
+                     maximum=0.5,
+                     step=0.05,
+                     value=0.3,
+                 )
+
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale for base",
+                     minimum=2,
+                     maximum=14,
+                     step=0.1,
+                     value=7.5,
+                 )
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps for base",
+                     minimum=4,
+                     maximum=50,
+                     step=1,
+                     value=8,
+                 )
+             with gr.Row():
+                 num_frames = gr.Slider(
+                     label="Number of Video Frames",
+                     minimum=16,
+                     maximum=48,
+                     step=8,
+                     value=16,
+                 )
+                 fps = gr.Slider(
+                     label="FPS",
+                     minimum=8,
+                     maximum=32,
+                     step=4,
+                     value=8,
+                 )
+
+         # Motion conditioning is a v2 feature (the v2 UNet adds the motion
+         # projection layer above), so enable it only for v2 checkpoints.
+         use_motion_cond = args.version == "v2"
+         generate = partial(
+             generate, use_motion_cond=use_motion_cond, motion_gs=args.motion_gs
+         )
+         gr.Examples(
+             examples=examples,
+             inputs=prompt,
+             outputs=result_video,
+             fn=generate,
+             cache_examples=CACHE_EXAMPLES,
+         )
+
+         gr.on(
+             triggers=[
+                 prompt.submit,
+                 run_button.click,
+             ],
+             fn=generate,
+             inputs=[
+                 prompt,
+                 seed,
+                 guidance_scale,
+                 percentage,
+                 num_inference_steps,
+                 num_frames,
+                 fps,
+                 randomize_seed,
+                 param_dtype,
+             ],
+             outputs=[result_video, seed],
+             api_name="run",
+         )
+
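+     # Serialize requests through Gradio's queue so concurrent users share the
+     # single GPU; api_open=False keeps the REST routes closed.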
+     demo.queue(api_open=False)
+     # demo.queue(max_size=20).launch()
+     demo.launch()