import os
import sys
import time
import torch
import gradio as gr
import numpy as np
import imageio
import spaces
from PIL import Image
# Add project root to path
# current_file_path = os.path.abspath(__file__)
# project_root = os.path.dirname(os.path.dirname(current_file_path))
# if project_root not in sys.path:
#     sys.path.insert(0, project_root)
from videox_fun.ui.wan_ui import Wan_Controller, css
from videox_fun.ui.ui import (
    create_model_type, create_model_checkpoints, create_finetune_models_checkpoints,
    create_teacache_params, create_cfg_skip_params, create_cfg_riflex_k,
    create_prompts, create_samplers, create_height_width,
    create_generation_methods_and_video_length, create_generation_method,
    create_cfg_and_seedbox, create_ui_outputs
)
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import save_videos_grid, timer

# English-only replacement for the library's create_height_width. VideoCoF always runs
# inference at the input video's resolution, so the resize controls are created only to
# satisfy the generate() signature and are kept hidden rather than editing the library file.
def create_height_width_english(default_height, default_width, maximum_height, maximum_width):
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        value="Generate by",
        show_label=False,
        visible=False,  # Hidden: generation always uses the input video resolution
    )
    # The sliders are hidden as well. They would only matter as a fallback when no video is
    # provided, but VideoCoF requires an input video, so they act as placeholders here.
    width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, visible=False)
    height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, visible=False)
    base_resolution = gr.Radio(label="Base Resolution", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)
    return resize_method, width_slider, height_slider, base_resolution


def load_video_frames(video_path: str, source_frames: int):
    assert source_frames is not None, "source_frames is required"
    reader = imageio.get_reader(video_path)
    try:
        total_frames = reader.count_frames()
    except Exception:
        # Fall back to counting by iteration, then reopen the exhausted reader.
        total_frames = sum(1 for _ in reader)
        reader = imageio.get_reader(video_path)
    stride = max(1, total_frames // source_frames)
    # Using a random start frame, as in inference.py
    start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()

    frames = []
    original_height, original_width = None, None
    for i in range(source_frames):
        idx = start_frame + i * stride
        if idx >= total_frames:
            break
        try:
            frame = reader.get_data(idx)
            pil_frame = Image.fromarray(frame)
            if original_height is None:
                original_width, original_height = pil_frame.size
                print(f"Original video dimensions: {original_width}x{original_height}")
            frames.append(pil_frame)
        except IndexError:
            break
    reader.close()

    # Pad with copies of the last frame (or black frames) if the clip is too short.
    while len(frames) < source_frames:
        if frames:
            frames.append(frames[-1].copy())
        else:
            w, h = (original_width, original_height) if original_width else (832, 480)
            frames.append(Image.new('RGB', (w, h), (0, 0, 0)))

    assert len(frames) == source_frames, f"Loaded {len(frames)} frames, expected {source_frames}"
    print(f"Loaded {source_frames} source frames")

    # (F, H, W, C) -> (1, C, F, H, W), scaled from [0, 255] to [-1, 1]
    input_video = torch.from_numpy(np.array(frames))
    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
    input_video = input_video * (2.0 / 255.0) - 1.0
    return input_video, original_height, original_width
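
# A minimal sketch of how load_video_frames is consumed (the clip is one of the bundled
# example assets; the frame count follows the "Source Frames" slider default):
#
#   video_tensor, h, w = load_video_frames("assets/two_man.mp4", source_frames=33)
#   # video_tensor.shape == (1, 3, 33, h, w), float32, values in [-1, 1]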


def preload_models(controller, default_model_path, default_lora_name, acc_lora_path):
    """
    Preload base model and LoRAs before launching the app to avoid first-run latency.
    """
    controller.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Ensure tracking flags exist
    if not hasattr(controller, "_active_lora_path"):
        controller._active_lora_path = None
    if not hasattr(controller, "_acc_lora_active"):
        controller._acc_lora_active = False

    try:
        print(f"[preload] Loading base model: {default_model_path}")
        controller.update_diffusion_transformer(default_model_path)

        # update_base_model expects files under Personalized_Model; skip if not present
        base_candidate = os.path.join(controller.personalized_model_dir, os.path.basename(default_model_path))
        if os.path.exists(base_candidate):
            controller.update_base_model(os.path.basename(base_candidate))
        else:
            print(f"[preload] Skip update_base_model (not found at {base_candidate})")

        print(f"[preload] Loading VideoCoF LoRA: {default_lora_name}")
        controller.update_lora_model(default_lora_name)
        if controller.lora_model_path and controller.lora_model_path != "none":
            controller.pipeline = merge_lora(
                controller.pipeline,
                controller.lora_model_path,
                multiplier=1.0,
                device=controller.device,
            )
            controller._active_lora_path = controller.lora_model_path

        if acc_lora_path and os.path.exists(acc_lora_path):
            print(f"[preload] Loading Acceleration LoRA: {acc_lora_path}")
            controller.pipeline = merge_lora(
                controller.pipeline, acc_lora_path, multiplier=1.0, device=controller.device
            )
            controller._acc_lora_active = True
        else:
            print(f"[preload] Acceleration LoRA not found at {acc_lora_path}")
    except Exception as e:
        print(f"[preload] Warning: preload failed: {e}")
    finally:
        torch.cuda.empty_cache()
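
# Roughly how this is invoked in ui() below (paths match the download step there):
#   acc_lora_path = os.path.join("models", "Personalized_Model",
#                                "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
#   preload_models(controller, os.path.join("models", "Wan2.1-T2V-14B"),
#                  "videocof.safetensors", acc_lora_path)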


class VideoCoF_Controller(Wan_Controller):
    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        ref_image=None,
        # Custom args
        source_frames_slider=33,
        reasoning_frames_slider=4,
        repeat_rope_checkbox=True,
        # New arg for acceleration
        enable_acceleration=True,
        fps=8,
        is_api=False,
    ):
        self.clear_cache()
        print("VideoCoF generation started.")

        # On ZeroGPU the CUDA device is only usable inside the GPU-allocated call, and
        # Wan_Controller relies on self.device, so pick the device here and move the
        # pipeline onto it to avoid CPU-side ops.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            if hasattr(self, "pipeline") and self.pipeline is not None:
                self.pipeline.to(self.device)
        except Exception as move_e:
            print(f"Warning: failed to move pipeline to {self.device}: {move_e}")

        if self.diffusion_transformer_dropdown != diffusion_transformer_dropdown:
            self.update_diffusion_transformer(diffusion_transformer_dropdown)
        if self.base_model_path != base_model_dropdown:
            self.update_base_model(base_model_dropdown)
        if self.lora_model_path != lora_model_dropdown:
            self.update_lora_model(lora_model_dropdown)

        # Track whether LoRAs are already merged to avoid repeated merges/unmerges.
        if not hasattr(self, "_active_lora_path"):
            self._active_lora_path = None
        if not hasattr(self, "_acc_lora_active"):
            self._acc_lora_active = False

        # Scheduler setup
        scheduler_config = self.pipeline.scheduler.config
        if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
            scheduler_config['shift'] = 1
        self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)

        # LoRA merging
        # 1. Merge VideoCoF LoRA
        if self.lora_model_path != "none":
            # If a different LoRA was previously merged, unmerge it first.
            if self._active_lora_path and self._active_lora_path != self.lora_model_path:
                print(f"Unmerging previous VideoCoF LoRA: {self._active_lora_path}")
                self.pipeline = unmerge_lora(self.pipeline, self._active_lora_path, multiplier=lora_alpha_slider, device=self.device)
                self._active_lora_path = None
            if self._active_lora_path != self.lora_model_path:
                print(f"Merge VideoCoF LoRA: {self.lora_model_path}")
                self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
                self._active_lora_path = self.lora_model_path

        # 2. Merge Acceleration LoRA (FusionX) if enabled
        acc_lora_path = os.path.join(self.personalized_model_dir, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
        if enable_acceleration:
            if os.path.exists(acc_lora_path):
                if not self._acc_lora_active:
                    print(f"Merge Acceleration LoRA: {acc_lora_path}")
                    # FusionX LoRA generally uses multiplier 1.0
                    self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
                    self._acc_lora_active = True
            else:
                print(f"Warning: Acceleration LoRA not found at {acc_lora_path}")
        else:
            # If it was previously merged but now disabled, unmerge once.
            if self._acc_lora_active and os.path.exists(acc_lora_path):
                print("Unmerging Acceleration LoRA (disabled)")
                self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
                self._acc_lora_active = False

        # Seed: an empty box or -1 means a random seed.
        if seed_textbox != "" and int(seed_textbox) != -1:
            torch.manual_seed(int(seed_textbox))
        else:
            seed_textbox = np.random.randint(0, 1e10)
        # Ensure the generator is created on the same device as the pipeline's transformer.
        gen_device = getattr(getattr(self, "pipeline", None), "transformer", None)
        gen_device = gen_device.device if gen_device is not None else self.device
        generator = torch.Generator(device=gen_device).manual_seed(int(seed_textbox))

        try:
            # VideoCoF logic: use validation_video as the source if provided (the standard UI
            # input for Video-to-Video), falling back to control_video otherwise.
            input_video_path = validation_video
            if input_video_path is None:
                input_video_path = control_video
            if input_video_path is None:
                raise ValueError("Please upload a video for VideoCoF generation.")

            # CoT prompt construction
            edit_text = prompt_textbox
            ground_instr = derive_ground_object_from_instruction(edit_text)
            prompt = (
                "A video sequence showing three parts: first the original scene, "
                f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
            )
            print(f"Constructed prompt: {prompt}")

            # Load video frames
            input_video_tensor, video_height, video_width = load_video_frames(
                input_video_path,
                source_frames=source_frames_slider
            )
            # Using loaded video dimensions
            h, w = video_height, video_width
            print(f"Input video dimensions: {w}x{h}")
            print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")

            shift = 3
            sample = self.pipeline(
                video=input_video_tensor,
                prompt=prompt,
                num_frames=length_slider,
                source_frames=source_frames_slider,
                reasoning_frames=reasoning_frames_slider,
                negative_prompt=negative_prompt_textbox,
                height=h,
                width=w,
                generator=generator,
                guidance_scale=cfg_scale_slider,
                num_inference_steps=sample_step_slider,
                shift=shift,
                repeat_rope=repeat_rope_checkbox,
                cot=True,
            ).videos
            # Keep only the edited segment (drop reasoning/original parts)
            final_video = sample[:, :, -source_frames_slider:, :, :]
        except Exception as e:
            print(f"Error: {e}")
            # Unmerge in case of error (LIFO order)
            if self._acc_lora_active and os.path.exists(acc_lora_path):
                print("Unmerging Acceleration LoRA (due to error)")
                self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
                self._acc_lora_active = False
            if self._active_lora_path:
                print("Unmerging VideoCoF LoRA (due to error)")
                self.pipeline = unmerge_lora(self.pipeline, self._active_lora_path, multiplier=lora_alpha_slider, device=self.device)
                self._active_lora_path = None
            return gr.update(), gr.update(), f"Error: {str(e)}"

        # Save output
        save_sample_path = self.save_outputs(
            False, source_frames_slider, final_video, fps=fps
        )

        # The outputs are (result_image, result_video, infer_progress); hide the unused image
        # slot and return the saved video. Displaying the uploaded source clip is handled by
        # the validation_video input component itself, so it is not part of this return value.
        return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"


def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
    controller = VideoCoF_Controller(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=config_path, compile_dit=compile_dit,
        weight_dtype=weight_dtype
    )

    with gr.Blocks() as demo:
        gr.Markdown("# VideoCoF Demo")

        with gr.Column(variant="panel"):
            # Hide model selection; the base model lives in a fixed local directory.
            local_model_dir = os.path.join("models", "Wan2.1-T2V-14B")
            diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model=local_model_dir)

            # Download weights up front: a snapshot of the Wan2.1-T2V-14B base model plus the
            # VideoCoF and FusionX LoRA safetensors.
            try:
                from huggingface_hub import snapshot_download, hf_hub_download

                print("Downloading Wan2.1-T2V-14B weights...")
                hf_model_id = "Wan-AI/Wan2.1-T2V-14B"
                snapshot_download(repo_id=hf_model_id, local_dir=local_model_dir, local_dir_use_symlinks=False)

                os.makedirs("models/Personalized_Model", exist_ok=True)
                print("Downloading VideoCoF weights...")
                default_lora_name = "videocof.safetensors"
                hf_hub_download(repo_id="XiangpengYang/VideoCoF", filename=default_lora_name, local_dir="models/Personalized_Model")

                print("Downloading FusionX Acceleration LoRA...")
                acc_lora_filename = "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors"
                hf_hub_download(repo_id="MonsterMMORPG/Wan_GGUF", filename=acc_lora_filename, local_dir="models/Personalized_Model")
            except Exception as e:
                print(f"Warning: Failed to pre-download weights: {e}")

            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(
                controller, visible=False, default_lora="videocof.safetensors"
            )
            # Set default LoRA alpha to 1.0 (matching inference.py)
            lora_alpha_slider.value = 1.0

            # Preload heavy weights and LoRAs before launching the UI to avoid first-run latency.
            acc_lora_path = os.path.join("models", "Personalized_Model", "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
            preload_models(controller, local_model_dir, "videocof.safetensors", acc_lora_path)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(prompt="Remove the young man with short black hair wearing black shirt on the left.")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)
                    # Default steps lowered to 4 for acceleration
                    sample_step_slider.value = 4

                    # Custom VideoCoF Params
                    with gr.Group():
                        gr.Markdown("### VideoCoF Parameters")
                        source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
                        reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
                        repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)
                        # Add Acceleration Checkbox
                        enable_acceleration = gr.Checkbox(label="Enable 4-step Acceleration (FusionX LoRA)", value=True)

                    # Use custom height/width creation to hide/customize
                    resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
                        default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
                    )

                    # Default video length 65
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation"],
                            default_video_length=65,
                            maximum_video_length=161
                        )

                    # Simplified input for VideoCoF: Video-to-Video only.
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
                        ["Video to Video"],
                        prompt_textbox,
                        support_end_image=False,
                        default_video="assets/two_man.mp4",
                        video_examples=[
                            ["assets/two_man.mp4", "Remove the young man with short black hair wearing black shirt on the left."],
                            ["assets/three_people.mp4", "Remove the man with short dark hair wearing a gray suit on the right."],
                            ["assets/office.mp4", "Remove the beige CRT computer setup."],
                            ["assets/woman_ballon.mp4", "Add the woman in a floral dress pointing at the balloon on the left."],
                            ["assets/greenhouse.mp4", "A white Samoyed is watching the man, who crouches in a greenhouse. The Samoyed is covered in thick, fluffy white fur, giving it a very soft and plush appearance. Its ears are erect and triangular, making it look alert and intelligent. The Samoyed's face features its signature smile, with bright black eyes that convey friendliness and curiosity."],
                            ["assets/gameplay.mp4", "Add the woman holding the blue game controller to the left of the man, engaged in gameplay."],
                            ["assets/dog.mp4", "Add the brown and white beagle interacting with and drinking from the metallic bowl on the wooden floor."],
                            ["assets/sign.mp4", "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."],
                            ["assets/old_man.mp4", "Swap the old man with long white hair and a blue checkered shirt at the left side of the frame with a woman with curly brown hair and a denim shirt."],
                            ["assets/pants.mp4", "Swap the white pants worn by the individual with the light blue jeans."],
                            ["assets/bowl.mp4", "Make the largest cup on the right white and smooth."],
                            ["assets/ketchup.mp4", "Make the ketchup bottle to the right of the BBQ sauce bottle violet color."],
                            ["assets/fruit.mp4", "Make the pomegranate at the right side of the basket lavender color."]
                        ],
                    )
                    # Ensure validation_video is visible and interactive
                    validation_video.visible = True
                    validation_video.interactive = True

                    # Set default seed to 0
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
                    seed_textbox.value = "0"
                    cfg_scale_slider.value = 1.0

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

        # Event handlers
        generate_button.click(
            fn=controller.generate,
            inputs=[
                diffusion_transformer_dropdown,
                base_model_dropdown,
                lora_model_dropdown,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
                ref_image,
                # New inputs
                source_frames_slider,
                reasoning_frames_slider,
                repeat_rope_checkbox,
                enable_acceleration
            ],
            outputs=[result_image, result_video, infer_progress]
        )

    return demo, controller


if __name__ == "__main__":
    from videox_fun.ui.controller import flow_scheduler_dict

    # Use CPU offload to reduce the GPU memory footprint in the Space.
    GPU_memory_mode = "sequential_cpu_offload"
    compile_dit = False
    weight_dtype = torch.bfloat16

    server_name = "0.0.0.0"
    server_port = 7860
    config_path = "config/wan2.1/wan_civitai.yaml"

    demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, config_path, compile_dit, weight_dtype)
    demo.queue(status_update_rate=1).launch(
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,
        share=False
    )
    # launch() does not block when prevent_thread_lock=True, so keep the main thread alive.
    while True:
        time.sleep(5)