"""Gradio demo for VideoCoF chain-of-frames video editing, built on VideoX-Fun's Wan UI."""
import time

import torch
import gradio as gr
import numpy as np
import imageio
from PIL import Image
from videox_fun.ui.wan_ui import Wan_Controller, css
from videox_fun.ui.ui import (
    create_model_checkpoints, create_finetune_models_checkpoints,
    create_teacache_params, create_cfg_skip_params, create_cfg_riflex_k,
    create_prompts, create_samplers,
    create_generation_methods_and_video_length, create_generation_method,
    create_cfg_and_seedbox, create_ui_outputs
)
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import timer


def create_height_width_english(default_height, default_width, maximum_height, maximum_width):
    """Create hidden resolution controls so `generate` receives its usual arguments.

    VideoCoF takes the output size from the uploaded video, so none of these
    widgets are shown to the user.
    """
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        value="Generate by",
        show_label=False,
        visible=False
    )
    width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, visible=False)
    height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, visible=False)
    base_resolution = gr.Radio(label="Base Resolution", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)

    return resize_method, width_slider, height_slider, base_resolution


def load_video_frames(video_path: str, source_frames: int):
    """Sample `source_frames` frames from the video at a uniform stride.

    Returns a float tensor of shape [1, C, F, H, W] normalized to [-1, 1],
    plus the original frame height and width.
    """
    assert source_frames is not None, "source_frames is required"

    reader = imageio.get_reader(video_path)
    try:
        total_frames = reader.count_frames()
    except Exception:
        # Some readers cannot report a count; count by iterating, then reopen.
        total_frames = sum(1 for _ in reader)
        reader = imageio.get_reader(video_path)

    # Uniform stride over the clip, starting at a random offset.
    stride = max(1, total_frames // source_frames)
    start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()

    frames = []
    original_height, original_width = None, None
    for i in range(source_frames):
        idx = start_frame + i * stride
        if idx >= total_frames:
            break
        try:
            frame = reader.get_data(idx)
            pil_frame = Image.fromarray(frame)
            if original_height is None:
                original_width, original_height = pil_frame.size
            frames.append(pil_frame)
        except IndexError:
            break

    reader.close()

    # Pad short reads by repeating the last frame, or black frames if nothing was read.
    while len(frames) < source_frames:
        if frames:
            frames.append(frames[-1].copy())
        else:
            w, h = (original_width, original_height) if original_width else (832, 480)
            frames.append(Image.new('RGB', (w, h), (0, 0, 0)))

    # [F, H, W, C] uint8 -> [1, C, F, H, W] float in [-1, 1].
    input_video = torch.from_numpy(np.array(frames))
    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
    input_video = input_video * (2.0 / 255.0) - 1.0

    return input_video, original_height, original_width
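# A minimal usage sketch (assuming the bundled clip is an 832x480 RGB video):
#   video, h, w = load_video_frames("assets/two_man.mp4", source_frames=33)
#   video.shape -> torch.Size([1, 3, 33, 480, 832]), values in [-1, 1]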
class VideoCoF_Controller(Wan_Controller):
    @timer
    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        ref_image=None,
        enable_teacache=None,
        teacache_threshold=None,
        num_skip_start_steps=None,
        teacache_offload=None,
        cfg_skip_ratio=None,
        enable_riflex=None,
        riflex_k=None,
        # VideoCoF-specific controls.
        source_frames_slider=33,
        reasoning_frames_slider=4,
        repeat_rope_checkbox=True,
        fps=10,
        is_api=False,
    ):
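        """Run VideoCoF chain-of-frames editing on an uploaded video.

        The signature mirrors `Wan_Controller.generate` so the shared UI widgets
        can be wired straight through. The resolution and generation-method
        arguments are accepted but unused: VideoCoF takes its output size from
        the input video. `fps` and `is_api` keep their defaults when called
        from the UI.
        """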
        self.clear_cache()
        print("VideoCoF generation started.")

        if self.diffusion_transformer_dropdown != diffusion_transformer_dropdown:
            self.update_diffusion_transformer(diffusion_transformer_dropdown)

        if self.base_model_path != base_model_dropdown:
            self.update_base_model(base_model_dropdown)

        if self.lora_model_path != lora_model_dropdown:
            self.update_lora_model(lora_model_dropdown)

        # Rebuild the scheduler for the chosen sampler; the flow samplers expect shift=1.
        scheduler_config = self.pipeline.scheduler.config
        if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
            scheduler_config['shift'] = 1
        self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)

        if self.lora_model_path != "none":
            print("Merging LoRA.")
            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        # Check for the empty string before calling int(): int("") raises ValueError.
        if seed_textbox != "" and int(seed_textbox) != -1:
            torch.manual_seed(int(seed_textbox))
        else:
            seed_textbox = np.random.randint(0, 1e10)
        generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))

        try:
            # Prefer the "Video to Video" upload, falling back to the control-video slot.
            input_video_path = validation_video
            if input_video_path is None:
                input_video_path = control_video
            if input_video_path is None:
                raise ValueError("Please upload a video for VideoCoF generation.")

            # Build the three-part chain-of-frames prompt:
            # original scene -> grounded target object -> edited scene.
            edit_text = prompt_textbox
            ground_instr = derive_ground_object_from_instruction(edit_text)
            prompt = (
                "A video sequence showing three parts: first the original scene, "
                f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
            )
            print(f"Constructed prompt: {prompt}")
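            # With the default instruction this yields, roughly (the grounded
            # phrase is whatever derive_ground_object_from_instruction returns):
            #   "A video sequence showing three parts: first the original scene,
            #    then grounded <object>, and finally the same scene but Remove the
            #    young man with short black hair wearing black shirt on the left."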
            # Load and normalize the source frames; output size follows the video.
            input_video_tensor, video_height, video_width = load_video_frames(
                input_video_path,
                source_frames=source_frames_slider
            )

            h, w = video_height, video_width
            print(f"Input video dimensions: {w}x{h}")

            print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")
            sample = self.pipeline(
                video=input_video_tensor,
                prompt=prompt,
                num_frames=length_slider,
                source_frames=source_frames_slider,
                reasoning_frames=reasoning_frames_slider,
                negative_prompt=negative_prompt_textbox,
                height=h,
                width=w,
                generator=generator,
                guidance_scale=cfg_scale_slider,
                num_inference_steps=sample_step_slider,
                repeat_rope=repeat_rope_checkbox,
                cot=True,
            ).videos
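            # The pipeline exposes its decoded frames on `.videos`; the tensor
            # is saved to disk unchanged below.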
            final_video = sample

        except Exception as e:
            print(f"Error: {e}")
            # Restore the base weights before surfacing the error.
            if self.lora_model_path != "none":
                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
            return gr.update(), gr.update(), f"Error: {str(e)}"

        if self.lora_model_path != "none":
            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        save_sample_path = self.save_outputs(
            False, length_slider, final_video, fps=fps
        )

        return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
    controller = VideoCoF_Controller(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=config_path, compile_dit=compile_dit,
        weight_dtype=weight_dtype
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# VideoCoF Demo")

        with gr.Column(variant="panel"):
            # Model and LoRA selectors are hidden and fixed to the VideoCoF defaults.
            diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model="Wan-AI/Wan2.1-T2V-14B")
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
            lora_alpha_slider.value = 1.0

            with gr.Row():
                enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = create_teacache_params(False, 0.10, 5, False)
                cfg_skip_ratio = create_cfg_skip_params(0)
                enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(prompt="Remove the young man with short black hair wearing black shirt on the left.")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    with gr.Group():
                        gr.Markdown("### VideoCoF Parameters")
                        source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
                        reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
                        repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)

                    # Hidden resolution controls; the output size follows the input video.
                    resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
                        default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
                    )

                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation"],
                            default_video_length=65,
                            maximum_video_length=161
                        )

                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
                        ["Video to Video"], prompt_textbox, support_end_image=False, default_video="assets/two_man.mp4",
                        video_examples=[
                            ["assets/two_man.mp4", "Remove the young man with short black hair wearing black shirt on the left."],
                            ["assets/sign.mp4", "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."]
                        ]
                    )
                    validation_video.visible = True
                    validation_video.interactive = True

                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
                    seed_textbox.value = "0"

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()
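        # `inputs` must stay in the same positional order as `generate`'s
        # signature; `fps` and `is_api` are not wired and keep their defaults.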
        generate_button.click(
            fn=controller.generate,
            inputs=[
                diffusion_transformer_dropdown,
                base_model_dropdown,
                lora_model_dropdown,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
                ref_image,
                enable_teacache,
                teacache_threshold,
                num_skip_start_steps,
                teacache_offload,
                cfg_skip_ratio,
                enable_riflex,
                riflex_k,

                source_frames_slider,
                reasoning_frames_slider,
                repeat_rope_checkbox
            ],
            outputs=[result_image, result_video, infer_progress]
        )
    return demo, controller


if __name__ == "__main__":
    from videox_fun.ui.controller import flow_scheduler_dict

    GPU_memory_mode = "sequential_cpu_offload"
    compile_dit = False
    weight_dtype = torch.bfloat16
    server_name = "0.0.0.0"
    server_port = 7860
    config_path = "config/wan2.1/wan_civitai.yaml"

    demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, config_path, compile_dit, weight_dtype)

    demo.queue(status_update_rate=1).launch(
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,
        share=False
    )

    # launch() returns immediately with prevent_thread_lock=True, so keep the
    # main process alive while the server runs.
    while True:
        time.sleep(5)