File size: 7,203 Bytes
26555ee
a6e5a9f
26555ee
 
 
 
 
 
 
 
6eb6fe1
004a144
e45b743
 
 
 
 
862185b
19801e5
 
a6e5a9f
004a144
26555ee
 
 
 
 
 
 
 
 
 
8c4bc8e
a6e5a9f
1dade9e
 
a6e5a9f
1dade9e
a6e5a9f
1dade9e
 
 
 
a6e5a9f
 
26555ee
 
 
 
1dade9e
 
a6e5a9f
 
 
 
1dade9e
 
a6e5a9f
 
 
 
1dade9e
 
a6e5a9f
 
 
 
1dade9e
 
a6e5a9f
 
 
 
26555ee
 
a6e5a9f
26555ee
a6e5a9f
26555ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448d195
 
 
 
 
26555ee
 
 
19801e5
a6e5a9f
26555ee
a6e5a9f
 
26555ee
 
 
 
 
1dade9e
26555ee
68a24d9
26555ee
a6e5a9f
 
 
 
 
 
 
 
 
 
 
 
 
 
26555ee
6352e71
26555ee
 
 
a6e5a9f
26555ee
a6e5a9f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
from text_to_video import model_t2v_fun, setup_seed
from omegaconf import OmegaConf
import torch
import imageio
import os
import cv2
import pandas as pd
import torchvision
import random
from models import get_models

from pipelines.pipeline_videogen import VideoGenPipeline
from download import find_model
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
from diffusers.models import AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection

config_path = "./base/configs/sample.yaml"
args = OmegaConf.load("./base/configs/sample.yaml")
device = "cpu"  # Force CPU usage

css = """
h1 {
  text-align: center;
}
#component-0 {
  max-width: 730px;
  margin: auto;
}
"""

sd_path = args.pretrained_path
unet = get_models(args, sd_path).to(device, dtype=torch.float32)  # Use float32 for CPU
state_dict = find_model("./pretrained_models/lavie_base.pt")
unet.load_state_dict(state_dict)
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float32).to(device)  # Use float32 for CPU
tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float32).to(device)  # Use float32 for CPU
unet.eval()
vae.eval()
text_encoder_one.eval()

def infer(prompt, seed_inp, ddim_steps, cfg, infer_type):
    if seed_inp != -1:
        setup_seed(seed_inp)
    else:
        seed_inp = random.choice(range(10000000))
        setup_seed(seed_inp)
    if infer_type == 'ddim':
        scheduler = DDIMScheduler.from_pretrained(sd_path, 
                                               subfolder="scheduler",
                                               beta_start=args.beta_start, 
                                               beta_end=args.beta_end, 
                                               beta_schedule=args.beta_schedule)
    elif infer_type == 'eulerdiscrete':
        scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
                                               subfolder="scheduler",
                                               beta_start=args.beta_start,
                                               beta_end=args.beta_end,
                                               beta_schedule=args.beta_schedule)
    elif infer_type == 'ddpm':
        scheduler = DDPMScheduler.from_pretrained(sd_path,
                                              subfolder="scheduler",
                                              beta_start=args.beta_start,
                                              beta_end=args.beta_end,
                                              beta_schedule=args.beta_schedule)
    model = VideoGenPipeline(vae=vae, text_encoder=text_encoder_one, tokenizer=tokenizer_one, scheduler=scheduler, unet=unet)
    model.to(device)
    # Disable xformers for CPU
    # if device == "cuda":
    #     model.enable_xformers_memory_efficient_attention()
    videos = model(prompt, video_length=8, height=160, width=256, num_inference_steps=ddim_steps, guidance_scale=cfg).video  # Reduced resolution and length
    if not os.path.exists(args.output_folder):
        os.mkdir(args.output_folder)
    torchvision.io.write_video(args.output_folder + prompt[0:30].replace(' ', '_') + '-' + str(seed_inp) + '-' + str(ddim_steps) + '-' + str(cfg) + '-.mp4', videos[0], fps=4)  # Reduced FPS

    return args.output_folder + prompt[0:30].replace(' ', '_') + '-' + str(seed_inp) + '-' + str(ddim_steps) + '-' + str(cfg) + '-.mp4'

title = """
    <div style="text-align: center; max-width: 700px; margin: 0 auto;">
        <div
        style="
            display: inline-flex;
            align-items: center;
            gap: 0.8rem;
            font-size: 1.75rem;
        "
        >
        <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
            Intern·Vchitect (Text-to-Video)
        </h1>
        </div>
        <p style="margin-bottom: 10px; font-size: 94%">
        Apply Intern·Vchitect to generate a video 
        </p>
    </div>
"""

with gr.Blocks(css='style.css') as demo:
    gr.Markdown("<font color=red size=10><center>LaVie: Text-to-Video generation</center></font>")
    gr.Markdown(
        """<div style="text-align:center">
        [<a href="https://arxiv.org/abs/2309.15103">Arxiv Report</a>] | [<a href="https://vchitect.github.io/LaVie-project/">Project Page</a>] | [<a href="https://github.com/Vchitect/LaVie">Github</a>]</div>
        """
    )
    with gr.Column():
        with gr.Row(elem_id="col-container"):
            with gr.Column():
                prompt = gr.Textbox(value="a corgi walking in the park at sunrise, oil painting style", label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in", min_width=200, lines=2)
                infer_type = gr.Dropdown(['ddpm', 'ddim', 'eulerdiscrete'], label='infer_type', value='ddim')
                ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=50, step=1)
                seed_inp = gr.Slider(value=-1, label="seed (for random generation, use -1)", show_label=True, minimum=-1, maximum=2147483647)
                cfg = gr.Number(label="guidance_scale", value=7.5)

            with gr.Column():
                submit_btn = gr.Button("Generate video")
                video_out = gr.Video(label="Video result", elem_id="video-output")

            inputs = [prompt, seed_inp, ddim_steps, cfg, infer_type]
            outputs = [video_out]

        ex = gr.Examples(
            examples=[['a corgi walking in the park at sunrise, oil painting style', 400, 50, 7, 'ddim'],
                      ['a cute teddy bear reading a book in the park, oil painting style, high quality', 700, 50, 7, 'ddim'],
                      ['an epic tornado attacking above a glowing city at night, the tornado is made of smoke, highly detailed', 230, 50, 7, 'ddim'],
                      ['a jar filled with fire, 4K video, 3D rendered, well-rendered', 400, 50, 7, 'ddim'],
                      ['a teddy bear walking in the park, oil painting style, high quality', 400, 50, 7, 'ddim'],
                      ['a teddy bear walking on the street, 2k, high quality', 100, 50, 7, 'ddim'],
                      ['a panda taking a selfie, 2k, high quality', 400, 50, 7, 'ddim'],
                      ['a polar bear playing drum kit in NYC Times Square, 4k, high resolution', 400, 50, 7, 'ddim'],
                      ['jungle river at sunset, ultra quality', 400, 50, 7, 'ddim'],
                      ['a shark swimming in clear Carribean ocean, 2k, high quality', 400, 50, 7, 'ddim'],
                      ['A steam train moving on a mountainside by Vincent van Gogh', 230, 50, 7, 'ddim'],
                      ['a confused grizzly bear in calculus class', 1000, 50, 7, 'ddim']],
            fn=infer,
            inputs=[prompt, seed_inp, ddim_steps, cfg, infer_type],
            outputs=[video_out],
            cache_examples=True,
        )
        ex.dataset.headers = [""]

    submit_btn.click(infer, inputs, outputs)

demo.queue(max_size=12, api_open=False).launch(show_api=False)