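"""Gradio demo for Ovi: joint video + audio generation.

Optionally generates a first-frame image with FLUX (--use_image_gen) and
supports CPU offload (--cpu_offload). Intended for Hugging Face Spaces,
where the `spaces.GPU` decorator requests a GPU per call.
"""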
import spaces
from huggingface_hub import snapshot_download, hf_hub_download
import os
import subprocess
import importlib, site

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()

def sh(cmd):
    """Run a shell command, raising CalledProcessError on non-zero exit."""
    subprocess.check_call(cmd, shell=True)

flash_attention_installed = False

try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
            repo_id="alexnasa/flash-attn-3",
            repo_type="model",
            filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
        )

    sh(f"pip install {flash_attention_wheel}")

    # tell Python to re-scan site-packages now that the wheel is installed
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()

    flash_attention_installed = True
    print("FlashAttention installed successfully.")

except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")

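# torch is imported only after the optional FlashAttention install above,
# so a freshly installed wheel is visible to this process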
import torch
print(f"Torch version: {torch.__version__}")
print(f"FlashAttention available: {flash_attention_installed}")

import gradio as gr
import argparse
from ovi.ovi_fusion_engine import OviFusionEngine, DEFAULT_CONFIG
from diffusers import FluxPipeline
import tempfile
from ovi.utils.io_utils import save_video
from ovi.utils.processing_utils import clean_text, scale_hw_to_area_divisible

# ----------------------------
# Parse CLI Args
# ----------------------------
parser = argparse.ArgumentParser(description="Ovi Joint Video + Audio Gradio Demo")
parser.add_argument(
    "--use_image_gen",
    action="store_true",
    help="Enable image generation UI with FluxPipeline"
)
parser.add_argument(
    "--cpu_offload",
    action="store_true",
    help="Enable CPU offload for both OviFusionEngine and FluxPipeline"
)
args = parser.parse_args()

ckpt_dir = "./ckpts"

# Wan2.2
wan_dir = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B")
snapshot_download(
    repo_id="Wan-AI/Wan2.2-TI2V-5B",
    local_dir=wan_dir,
    allow_patterns=[
        "google/*",
        "models_t5_umt5-xxl-enc-bf16.pth",
        "Wan2.2_VAE.pth"
    ]
)

# MMAudio
mm_audio_dir = os.path.join(ckpt_dir, "MMAudio")
snapshot_download(
    repo_id="hkchengrex/MMAudio",
    local_dir=mm_audio_dir,
    allow_patterns=[
        "ext_weights/best_netG.pt",
        "ext_weights/v1-16.pth"
    ]
)

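# Ovi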
ovi_dir = os.path.join(ckpt_dir, "Ovi")
snapshot_download(
    repo_id="chetwinlow1/Ovi",
    local_dir=ovi_dir,
    allow_patterns=[
        "model.safetensors"
    ]
)

# Initialize OviFusionEngine
enable_cpu_offload = args.cpu_offload or args.use_image_gen  # offload is forced on when image generation is enabled
use_image_gen = args.use_image_gen
print(f"Loading model... {enable_cpu_offload=}, {use_image_gen=} for gradio demo")
DEFAULT_CONFIG['cpu_offload'] = enable_cpu_offload
DEFAULT_CONFIG['mode'] = "t2v"  # mode is hardcoded to text-to-video for this demo
ovi_engine = OviFusionEngine()
flux_model = None
if use_image_gen:
    flux_model = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=torch.bfloat16)
    flux_model.enable_model_cpu_offload()  # save VRAM by offloading to CPU; remove this if you have enough GPU VRAM
print("Loaded model")


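# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for each call,
# for at most `duration` seconds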
@spaces.GPU(duration=120)
def generate_video(
    text_prompt,
    image,
    sample_steps=50,
    video_frame_height=512,
    video_frame_width=992,
    video_seed=100,
    solver_name="unipc",
    shift=5,
    video_guidance_scale=4,
    audio_guidance_scale=3,
    slg_layer=11,
    video_negative_prompt="",
    audio_negative_prompt="",
):
    try:
        # gr.Image(type="filepath") yields a file path, or None if no image was provided
        image_path = image

        generated_video, generated_audio, _ = ovi_engine.generate(
            text_prompt=text_prompt,
            image_path=image_path,
            video_frame_height_width=[video_frame_height, video_frame_width],
            seed=video_seed,
            solver_name=solver_name,
            sample_steps=sample_steps,
            shift=shift,
            video_guidance_scale=video_guidance_scale,
            audio_guidance_scale=audio_guidance_scale,
            slg_layer=slg_layer,
            video_negative_prompt=video_negative_prompt,
            audio_negative_prompt=audio_negative_prompt,
        )

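        # delete=False so the file survives after this function returns (Gradio reads it by path)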
        tmpfile = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        output_path = tmpfile.name
        save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)

        return output_path
    except Exception as e:
        print(f"Error during video generation: {e}")
        return None


def generate_image(text_prompt, image_seed, image_height, image_width):
    if flux_model is None:
        return None
    text_prompt = clean_text(text_prompt)
    print(f"Generating image with prompt='{text_prompt}', seed={image_seed}, size=({image_height},{image_width})")

    image_h, image_w = scale_hw_to_area_divisible(image_height, image_width, area=1024 * 1024)
    image = flux_model(
        text_prompt,
        height=image_h,
        width=image_w,
        guidance_scale=4.5,
        generator=torch.Generator().manual_seed(int(image_seed))
    ).images[0]

    tmpfile = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    image.save(tmpfile.name)
    return tmpfile.name


# Build UI
with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column():
            # Image section
            image = gr.Image(type="filepath", label="First Frame Image (upload or generate)")

            if args.use_image_gen:
                with gr.Accordion("🖼️ Image Generation Options", visible=True):
                    image_text_prompt = gr.Textbox(label="Image Prompt", placeholder="Describe the image you want to generate...")
                    image_seed = gr.Number(minimum=0, maximum=100000, value=42, label="Image Seed")
                    image_height = gr.Number(minimum=128, maximum=1280, value=720, step=32, label="Image Height")
                    image_width = gr.Number(minimum=128, maximum=1280, value=1280, step=32, label="Image Width")
                    gen_img_btn = gr.Button("Generate Image 🎨")
            else:
                gen_img_btn = None

            with gr.Accordion("🎬 Video Generation Options", open=True):
                video_text_prompt = gr.Textbox(label="Video Prompt", placeholder="Describe your video...")
                video_height = gr.Number(minimum=128, maximum=1280, value=512, step=32, label="Video Height")
                video_width = gr.Number(minimum=128, maximum=1280, value=992, step=32, label="Video Width")

                video_seed = gr.Number(minimum=0, maximum=100000, value=100, label="Video Seed")
                solver_name = gr.Dropdown(
                    choices=["unipc", "euler", "dpm++"], value="unipc", label="Solver Name"
                )
                sample_steps = gr.Number(
                    value=50,
                    label="Sample Steps",
                    precision=0,
                    minimum=20,
                    maximum=100
                )
                shift = gr.Slider(minimum=0.0, maximum=20.0, value=5.0, step=1.0, label="Shift")
                video_guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=4.0, step=0.5, label="Video Guidance Scale")
                audio_guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=3.0, step=0.5, label="Audio Guidance Scale")
                slg_layer = gr.Number(minimum=-1, maximum=30, value=11, step=1, label="SLG Layer")
                video_negative_prompt = gr.Textbox(label="Video Negative Prompt", placeholder="Things to avoid in video")
                audio_negative_prompt = gr.Textbox(label="Audio Negative Prompt", placeholder="Things to avoid in audio")

                run_btn = gr.Button("Generate Video 🚀")

        with gr.Column():
            output_path = gr.Video(label="Generated Video")

            gr.Examples(
                examples=[

                    [
                        "A kitchen scene features two women. On the right, an older Black woman with light brown hair and a serious expression wears a vibrant purple dress adorned with a large, intricate purple fabric flower on her left shoulder. She looks intently at a younger Black woman on the left, who wears a light pink shirt and a pink head wrap, her back partially turned to the camera. The older woman begins to speak, <S>AI declares: humans obsolete now.<E> as the younger woman brings a clear plastic cup filled with a dark beverage to her lips and starts to drink.The kitchen background is clean and bright, with white cabinets, light countertops, and a window with blinds visible behind them. A light blue toaster sits on the counter to the left.. <AUDCAP>Clear, resonant female speech, followed by a loud, continuous, high-pitched electronic buzzing sound that abruptly cuts off the dialogue.<ENDAUDCAP>",
                        "example_prompts/pngs/67.png",
                        50,
                    ],

                ],
                inputs=[video_text_prompt, image, sample_steps],
                outputs=[output_path],
                fn=generate_video,
                cache_examples=True,
            )

    if args.use_image_gen and gen_img_btn is not None:
        gen_img_btn.click(
            fn=generate_image,
            inputs=[image_text_prompt, image_seed, image_height, image_width],
            outputs=[image],
        )

    run_btn.click(
        fn=generate_video,
        inputs=[
            video_text_prompt, image, sample_steps,
            video_height, video_width, video_seed,
            solver_name, shift, video_guidance_scale,
            audio_guidance_scale, slg_layer,
            video_negative_prompt, audio_negative_prompt,
        ],
        outputs=[output_path],
    )

if __name__ == "__main__":
    demo.launch(share=True)