"""Gradio demo for VideoCoF video editing, built on Wan2.1 via VideoX-Fun."""
import os
import sys
import time

import torch
import gradio as gr
import numpy as np
import imageio
import spaces
from PIL import Image

from videox_fun.ui.wan_ui import Wan_Controller, css
from videox_fun.ui.ui import (
    create_model_type, create_model_checkpoints, create_finetune_models_checkpoints,
    create_teacache_params, create_cfg_skip_params, create_cfg_riflex_k,
    create_prompts, create_samplers, create_height_width,
    create_generation_methods_and_video_length, create_generation_method,
    create_cfg_and_seedbox, create_ui_outputs
)
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import save_videos_grid, timer


def create_height_width_english(default_height, default_width, maximum_height, maximum_width):
    """Build hidden resolution controls (English labels).

    VideoCoF takes its output resolution from the uploaded video, so these
    controls exist only to satisfy the shared generate() signature and are
    kept invisible.
    """
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        value="Generate by",
        show_label=False,
        visible=False
    )
    width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, visible=False)
    height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, visible=False)
    base_resolution = gr.Radio(label="Base Resolution", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)

    return resize_method, width_slider, height_slider, base_resolution


def load_video_frames(video_path: str, source_frames: int):
    """Sample `source_frames` frames from a video and pack them into a tensor.

    Frames are read with a stride so the clip roughly covers the whole video,
    padded by repeating the last frame if the video is too short, and returned
    as a (1, C, F, H, W) float tensor normalized to [-1, 1], together with the
    original frame height and width.
    """
    assert source_frames is not None, "source_frames is required"

    reader = imageio.get_reader(video_path)
    try:
        total_frames = reader.count_frames()
    except Exception:
        # Some formats cannot report a frame count; count by iterating, then reopen.
        total_frames = sum(1 for _ in reader)
        reader = imageio.get_reader(video_path)

    stride = max(1, total_frames // source_frames)

    # Randomly pick a starting frame so repeated runs sample different clips.
    start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()

    frames = []
    original_height, original_width = None, None

    for i in range(source_frames):
        idx = start_frame + i * stride
        if idx >= total_frames:
            break
        try:
            frame = reader.get_data(idx)
            pil_frame = Image.fromarray(frame)
            if original_height is None:
                original_width, original_height = pil_frame.size
            frames.append(pil_frame)
        except IndexError:
            break

    reader.close()

    # Pad short videos by repeating the last frame (or black frames if nothing was read).
    while len(frames) < source_frames:
        if frames:
            frames.append(frames[-1].copy())
        else:
            w, h = (original_width, original_height) if original_width else (832, 480)
            frames.append(Image.new('RGB', (w, h), (0, 0, 0)))

    # (F, H, W, C) -> (1, C, F, H, W), rescaled from [0, 255] to [-1, 1].
    input_video = torch.from_numpy(np.array(frames))
    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
    input_video = input_video * (2.0 / 255.0) - 1.0

    return input_video, original_height, original_width


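# Illustrative example (not executed): for an 832x480 clip such as
# "assets/two_man.mp4" and the default source_frames=33, load_video_frames(path, 33)
# should return a tensor of shape (1, 3, 33, 480, 832) with values in [-1, 1],
# together with the original height and width (480, 832).
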
class VideoCoF_Controller(Wan_Controller):
    @spaces.GPU(duration=300)
    @timer
    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        ref_image=None,
        # VideoCoF-specific parameters
        source_frames_slider=33,
        reasoning_frames_slider=4,
        repeat_rope_checkbox=True,
        # Optional FusionX acceleration LoRA
        enable_acceleration=False,
        fps=8,
        is_api=False,
    ):
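        """Run VideoCoF editing on the uploaded video and return the result for the UI.

        The positional arguments follow the shared VideoX-Fun UI inputs wired up in
        ui(); the trailing keyword arguments (source frames, reasoning frames,
        repeat RoPE, acceleration) are VideoCoF-specific.
        """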
        self.clear_cache()
        print("VideoCoF Generation started.")

        if torch.cuda.is_available():
            self.device = torch.device("cuda")

        # Reload model components only when the selection changed.
        if self.diffusion_transformer_dropdown != diffusion_transformer_dropdown:
            self.update_diffusion_transformer(diffusion_transformer_dropdown)

        if self.base_model_path != base_model_dropdown:
            self.update_base_model(base_model_dropdown)

        if self.lora_model_path != lora_model_dropdown:
            self.update_lora_model(lora_model_dropdown)

        scheduler_config = self.pipeline.scheduler.config
        if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
            scheduler_config['shift'] = 1
        self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)

        if self.lora_model_path != "none":
            print(f"Merge VideoCoF LoRA: {self.lora_model_path}")
            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        # Optional 4-step acceleration LoRA (FusionX).
        acc_lora_path = os.path.join(self.personalized_model_dir, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
        if enable_acceleration:
            if os.path.exists(acc_lora_path):
                print(f"Merge Acceleration LoRA: {acc_lora_path}")
                self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
            else:
                print(f"Warning: Acceleration LoRA not found at {acc_lora_path}")

        # Check for an empty seed box before calling int() to avoid a ValueError.
        if seed_textbox != "" and int(seed_textbox) != -1:
            torch.manual_seed(int(seed_textbox))
        else:
            seed_textbox = np.random.randint(0, int(1e10))
        generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))

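        # VideoCoF produces one output video whose prompt (built below) describes three
        # parts: the original scene, the grounded object, and the edited scene. The
        # pipeline call therefore receives the source and reasoning frame counts in
        # addition to the total frame count (defaults: 33 source + 4 reasoning frames,
        # 65 frames in total).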
        try:
            # Prefer the uploaded reference video; fall back to the control video input.
            input_video_path = validation_video
            if input_video_path is None:
                input_video_path = control_video
            if input_video_path is None:
                raise ValueError("Please upload a video for VideoCoF generation.")

            # Build the three-part VideoCoF prompt from the edit instruction.
            edit_text = prompt_textbox
            ground_instr = derive_ground_object_from_instruction(edit_text)
            prompt = (
                "A video sequence showing three parts: first the original scene, "
                f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
            )
            print(f"Constructed prompt: {prompt}")

            # Load the source clip and generate at its native resolution.
            input_video_tensor, video_height, video_width = load_video_frames(
                input_video_path,
                source_frames=source_frames_slider
            )

            h, w = video_height, video_width
            print(f"Input video dimensions: {w}x{h}")

            print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")

            sample = self.pipeline(
                video=input_video_tensor,
                prompt=prompt,
                num_frames=length_slider,
                source_frames=source_frames_slider,
                reasoning_frames=reasoning_frames_slider,
                negative_prompt=negative_prompt_textbox,
                height=h,
                width=w,
                generator=generator,
                guidance_scale=cfg_scale_slider,
                num_inference_steps=sample_step_slider,
                repeat_rope=repeat_rope_checkbox,
                cot=True,
            ).videos

            final_video = sample

        except Exception as e:
            print(f"Error: {e}")

            # Restore the pipeline before returning the error to the UI.
            if enable_acceleration and os.path.exists(acc_lora_path):
                print("Unmerging Acceleration LoRA (due to error)")
                self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)

            if self.lora_model_path != "none":
                print("Unmerging VideoCoF LoRA (due to error)")
                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
            return gr.update(), gr.update(), f"Error: {str(e)}"

        # Unmerge LoRAs so later runs start from the clean base weights.
        if enable_acceleration and os.path.exists(acc_lora_path):
            print("Unmerging Acceleration LoRA")
            self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)

        if self.lora_model_path != "none":
            print("Unmerging VideoCoF LoRA")
            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        save_sample_path = self.save_outputs(
            False, length_slider, final_video, fps=fps
        )

        return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"


def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
    """Build the VideoCoF Gradio demo and its controller."""
    controller = VideoCoF_Controller(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=config_path, compile_dit=compile_dit,
        weight_dtype=weight_dtype
    )

    with gr.Blocks() as demo:
        gr.Markdown("# VideoCoF Demo")

        with gr.Column(variant="panel"):
            diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model="Wan-AI/Wan2.1-T2V-14B")

            # Pre-download the base model and LoRA weights while the UI is being built.
            try:
                from huggingface_hub import snapshot_download, hf_hub_download
                print("Downloading Wan2.1-T2V-14B weights...")
                snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="Wan-AI/Wan2.1-T2V-14B")

                os.makedirs("models/Personalized_Model", exist_ok=True)

                print("Downloading VideoCoF weights...")
                hf_hub_download(repo_id="XiangpengYang/VideoCoF", filename="videocof.safetensors", local_dir="models/Personalized_Model")

                print("Downloading FusionX Acceleration LoRA...")
                hf_hub_download(repo_id="MonsterMMORPG/Wan_GGUF", filename="Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors", local_dir="models/Personalized_Model")

            except Exception as e:
                print(f"Warning: Failed to pre-download weights: {e}")

            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="videocof.safetensors")

            lora_alpha_slider.value = 1.0

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(prompt="Remove the young man with short black hair wearing black shirt on the left.")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    # Default to 4 denoising steps to match the 4-step FusionX acceleration LoRA.
                    sample_step_slider.value = 4

                    with gr.Group():
                        gr.Markdown("### VideoCoF Parameters")
                        source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
                        reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
                        repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)

                        enable_acceleration = gr.Checkbox(label="Enable 4-step Acceleration (FusionX LoRA)", value=False)

                    resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
                        default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
                    )

                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation"],
                            default_video_length=65,
                            maximum_video_length=161
                        )

                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
                        ["Video to Video"], prompt_textbox, support_end_image=False, default_video="assets/two_man.mp4",
                        video_examples=[
                            ["assets/two_man.mp4", "Remove the young man with short black hair wearing black shirt on the left."],
                            ["assets/sign.mp4", "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."]
                        ]
                    )

                    # Ensure the uploaded video control is shown and editable.
                    validation_video.visible = True
                    validation_video.interactive = True

                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
                    seed_textbox.value = "0"
                    # guidance_scale of 1.0 effectively disables classifier-free guidance.
                    cfg_scale_slider.value = 1.0

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

        generate_button.click(
            fn=controller.generate,
            inputs=[
                diffusion_transformer_dropdown,
                base_model_dropdown,
                lora_model_dropdown,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
                ref_image,
                # VideoCoF-specific inputs
                source_frames_slider,
                reasoning_frames_slider,
                repeat_rope_checkbox,
                enable_acceleration
            ],
            outputs=[result_image, result_video, infer_progress]
        )

    return demo, controller


if __name__ == "__main__":
    from videox_fun.ui.controller import flow_scheduler_dict

    GPU_memory_mode = "sequential_cpu_offload"
    compile_dit = False
    weight_dtype = torch.bfloat16
    server_name = "0.0.0.0"
    server_port = 7860
    config_path = "config/wan2.1/wan_civitai.yaml"

    demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, config_path, compile_dit, weight_dtype)

    demo.queue(status_update_rate=1).launch(
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,
        share=False
    )

    # launch() does not block (prevent_thread_lock=True), so keep the process alive.
    while True:
        time.sleep(5)