File size: 16,709 Bytes
42a2bfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
import os
import sys
import time
import torch
import gradio as gr
import numpy as np
import imageio
from PIL import Image

# Add project root to path
# current_file_path = os.path.abspath(__file__)
# project_root = os.path.dirname(os.path.dirname(current_file_path))
# if project_root not in sys.path:
#    sys.path.insert(0, project_root)

from videox_fun.ui.wan_ui import Wan_Controller, css
from videox_fun.ui.ui import (
    create_model_type, create_model_checkpoints, create_finetune_models_checkpoints,
    create_teacache_params, create_cfg_skip_params, create_cfg_riflex_k,
    create_prompts, create_samplers, create_height_width,
    create_generation_methods_and_video_length, create_generation_method,
    create_cfg_and_seedbox, create_ui_outputs
)
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import save_videos_grid, timer

# English-only replacement for the library's `create_height_width`, defined here
# so the shared UI module does not need to be modified. VideoCoF runs inference
# at the input video's own resolution, so the size controls created below are
# hidden and only exist to keep the generate() input list compatible.

def create_height_width_english(default_height, default_width, maximum_height, maximum_width):
    """Create the hidden resolution controls expected by the generate() inputs.

    VideoCoF takes its output size from the uploaded video, so none of these
    components are shown to the user. They are still created (in the order
    resize method, width, height, base resolution) so the event handler's
    input list keeps the same shape as the stock Wan UI.

    Args:
        default_height: Initial value of the height slider.
        default_width: Initial value of the width slider.
        maximum_height: Upper bound of the height slider.
        maximum_width: Upper bound of the width slider.

    Returns:
        Tuple ``(resize_method, width_slider, height_slider, base_resolution)``.
    """
    hidden = dict(visible=False)

    # Radio kept only for signature compatibility; the input resolution always wins.
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        value="Generate by",
        show_label=False,
        **hidden,
    )

    # Fallback size controls; hidden because the video's resolution is used instead.
    width_slider = gr.Slider(
        label="Width",
        value=default_width,
        minimum=128,
        maximum=maximum_width,
        step=16,
        **hidden,
    )
    height_slider = gr.Slider(
        label="Height",
        value=default_height,
        minimum=128,
        maximum=maximum_height,
        step=16,
        **hidden,
    )
    base_resolution = gr.Radio(
        label="Base Resolution",
        value=512,
        choices=[512, 640, 768, 896, 960, 1024],
        **hidden,
    )

    return resize_method, width_slider, height_slider, base_resolution

def load_video_frames(video_path: str, source_frames: int):
    """Sample ``source_frames`` evenly strided RGB frames from a video file.

    The start offset is drawn with ``torch.randint`` from torch's global RNG,
    so seed torch beforehand for reproducible sampling. If the video yields
    fewer frames than requested, the last decoded frame is repeated. If no
    frame can be decoded at all, black 832x480 frames are used, and that
    fallback size is what gets reported.

    Args:
        video_path: Path to a video readable by ``imageio.get_reader``.
        source_frames: Number of frames to return; must be >= 1.

    Returns:
        ``(input_video, height, width)``. ``input_video`` is a float tensor of
        shape ``(1, 3, source_frames, H, W)`` scaled to ``[-1, 1]``;
        ``height``/``width`` are the frame size in pixels.

    Raises:
        ValueError: if ``source_frames`` is missing or smaller than 1.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if source_frames is None:
        raise ValueError("source_frames is required")
    # UI sliders may deliver floats; range()/division below need an int.
    source_frames = int(source_frames)
    if source_frames < 1:
        raise ValueError(f"source_frames must be >= 1, got {source_frames}")

    frames = []
    reader = imageio.get_reader(video_path)
    try:
        try:
            total_frames = reader.count_frames()
        except Exception:
            # Backend cannot report a count: decode once to count, then reopen
            # because iterating consumed the reader. Close the spent one first.
            total_frames = sum(1 for _ in reader)
            reader.close()
            reader = imageio.get_reader(video_path)

        stride = max(1, total_frames // source_frames)
        # Random start frame, as in inference.py.
        start_frame = torch.randint(
            0, max(1, total_frames - stride * source_frames), (1,)
        )[0].item()

        for i in range(source_frames):
            idx = start_frame + i * stride
            if idx >= total_frames:
                break
            try:
                frame = reader.get_data(idx)
            except IndexError:
                break
            # Normalise to 3 channels so the (T, H, W, C) stack below always has C == 3.
            frames.append(Image.fromarray(frame).convert("RGB"))
    finally:
        # Always release the decoder, even if decoding fails unexpectedly.
        reader.close()

    if frames:
        original_width, original_height = frames[0].size
        last_frame = frames[-1]
        # Pad short clips by repeating the last decoded frame.
        frames.extend(last_frame.copy() for _ in range(source_frames - len(frames)))
    else:
        # Nothing decodable: use black frames and report their size instead of None,
        # so downstream height/width arguments stay valid.
        original_width, original_height = 832, 480
        frames = [
            Image.new("RGB", (original_width, original_height), (0, 0, 0))
            for _ in range(source_frames)
        ]

    # (T, H, W, C) uint8 -> (1, C, T, H, W) float in [-1, 1].
    input_video = torch.from_numpy(np.array(frames))
    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
    input_video = input_video * (2.0 / 255.0) - 1.0

    return input_video, original_height, original_width

class VideoCoF_Controller(Wan_Controller):
    """Wan controller whose ``generate`` performs a VideoCoF video edit.

    The source video is sampled into ``source_frames`` frames. The edit
    instruction is wrapped into a three-part prompt (original scene, grounded
    object, edited scene), and the pipeline runs at the source video's own
    resolution.
    """

    @timer
    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        ref_image=None,
        enable_teacache=None,
        teacache_threshold=None,
        num_skip_start_steps=None,
        teacache_offload=None,
        cfg_skip_ratio=None,
        enable_riflex=None,
        riflex_k=None,
        # Custom args
        source_frames_slider=33,
        reasoning_frames_slider=4,
        repeat_rope_checkbox=True,
        fps=10,
        is_api=False,
    ):
        """Run one VideoCoF edit.

        The signature mirrors the stock Wan UI so the same Gradio input list
        can be reused. Only the following are consumed here:
        - the model and LoRA selectors;
        - the prompts and sampler settings;
        - ``length_slider`` and ``cfg_scale_slider``;
        - ``validation_video`` (falling back to ``control_video``);
        - ``seed_textbox`` (empty or ``-1`` means random);
        - ``source_frames_slider``, ``reasoning_frames_slider``,
          ``repeat_rope_checkbox`` and ``fps``.

        Resolution controls and the remaining arguments are ignored: inference
        uses the input video's size.

        Returns:
            ``(result_image, result_video, status)`` for the UI output slots.
            On failure the first two are ``gr.update()`` no-ops and ``status``
            carries the error message.
        """
        self.clear_cache()
        print("VideoCoF Generation started.")

        if self.diffusion_transformer_dropdown != diffusion_transformer_dropdown:
            self.update_diffusion_transformer(diffusion_transformer_dropdown)

        if self.base_model_path != base_model_dropdown:
            self.update_base_model(base_model_dropdown)

        if self.lora_model_path != lora_model_dropdown:
            self.update_lora_model(lora_model_dropdown)

        # Scheduler setup. Work on a copy so the live scheduler's config object
        # is not mutated in place.
        scheduler_config = dict(self.pipeline.scheduler.config)
        if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
            scheduler_config['shift'] = 1
        self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)

        # Resolve the seed BEFORE merging LoRA weights, so an invalid value can
        # never leave the pipeline in a merged state. The empty check must come
        # before int(), which would raise on "".
        seed_text = "" if seed_textbox is None else str(seed_textbox).strip()
        try:
            seed = int(seed_text) if seed_text else -1
        except ValueError:
            return gr.update(), gr.update(), f"Error: invalid seed {seed_textbox!r}"
        if seed == -1:
            # Explicit int64 bound: a float 1e10 overflows the default int32 on some platforms.
            seed = int(np.random.randint(0, 10**10, dtype=np.int64))
        # Seed the global RNG as well: load_video_frames draws its start frame
        # from it, so the whole run is reproducible from `seed`.
        torch.manual_seed(seed)
        generator = torch.Generator(device=self.device).manual_seed(seed)

        lora_merged = self.lora_model_path != "none"
        if lora_merged:
            print("Merge Lora.")
            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        try:
            # UI sliders may deliver floats; frame/step counts must be ints.
            num_frames = int(length_slider)
            source_frames = int(source_frames_slider)
            reasoning_frames = int(reasoning_frames_slider)
            num_inference_steps = int(sample_step_slider)

            # validation_video is the standard Video-to-Video input; control_video is a fallback.
            input_video_path = validation_video if validation_video is not None else control_video
            if input_video_path is None:
                raise ValueError("Please upload a video for VideoCoF generation.")

            # Chain-of-frames prompt: original scene -> grounded object -> edited scene.
            edit_text = prompt_textbox
            ground_instr = derive_ground_object_from_instruction(edit_text)
            prompt = (
                "A video sequence showing three parts: first the original scene, "
                f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
            )
            print(f"Constructed prompt: {prompt}")

            input_video_tensor, h, w = load_video_frames(
                input_video_path,
                source_frames=source_frames,
            )
            # Inference runs at the loaded video's resolution; the UI size sliders are ignored.
            print(f"Input video dimensions: {w}x{h}")

            print(f"Running pipeline with frames={num_frames}, source={source_frames}, reasoning={reasoning_frames}")

            final_video = self.pipeline(
                video=input_video_tensor,
                prompt=prompt,
                num_frames=num_frames,
                source_frames=source_frames,
                reasoning_frames=reasoning_frames,
                negative_prompt=negative_prompt_textbox,
                height=h,
                width=w,
                generator=generator,
                guidance_scale=cfg_scale_slider,
                num_inference_steps=num_inference_steps,
                repeat_rope=repeat_rope_checkbox,
                cot=True,
            ).videos
        except Exception as e:
            print(f"Error: {e}")
            return gr.update(), gr.update(), f"Error: {str(e)}"
        finally:
            # Restore the base weights on both the success and the failure path.
            if lora_merged:
                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        save_sample_path = self.save_outputs(
            False, num_frames, final_video, fps=fps
        )

        # Output slots are (result_image, result_video, status); VideoCoF only
        # produces a video, so the image slot is hidden.
        return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"

def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
    """Build the VideoCoF Gradio demo.

    Args:
        GPU_memory_mode: Memory/offload mode forwarded to the controller.
        scheduler_dict: Sampler-name -> scheduler-class mapping forwarded to
            the controller.
        config_path: Model config file forwarded to the controller.
        compile_dit: Forwarded to the controller.
        weight_dtype: Forwarded to the controller.

    Returns:
        ``(demo, controller)``: the (not yet launched) ``gr.Blocks`` app and
        the ``VideoCoF_Controller`` that backs it.
    """
    controller = VideoCoF_Controller(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=config_path, compile_dit=compile_dit,
        weight_dtype=weight_dtype
    )

    # Apply the stylesheet exported by the Wan UI module that the shared
    # create_* component helpers below come from.
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# VideoCoF Demo")

        with gr.Column(variant="panel"):
            # Model selection is fixed for this demo, so the selectors are hidden.
            diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model="Wan-AI/Wan2.1-T2V-14B")
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")

            # Default LoRA alpha of 1.0 (matching inference.py).
            lora_alpha_slider.value = 1.0

            with gr.Row():
                # TeaCache disabled by default.
                enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = create_teacache_params(False, 0.10, 5, False)
                cfg_skip_ratio = create_cfg_skip_params(0)
                enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(prompt="Remove the young man with short black hair wearing black shirt on the left.")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    # VideoCoF-specific parameters.
                    with gr.Group():
                        gr.Markdown("### VideoCoF Parameters")
                        source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
                        reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
                        repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)

                    # Hidden size controls: inference uses the input video's resolution.
                    resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
                        default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
                    )

                    # Default video length 65.
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation"],
                            default_video_length=65,
                            maximum_video_length=161
                        )

                    # VideoCoF only supports Video to Video.
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
                        ["Video to Video"], prompt_textbox, support_end_image=False, default_video="assets/two_man.mp4",
                        video_examples=[
                            ["assets/two_man.mp4", "Remove the young man with short black hair wearing black shirt on the left."],
                            ["assets/sign.mp4", "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."]
                        ]
                    )

                    # The source video input must be visible and editable.
                    validation_video.visible = True
                    validation_video.interactive = True

                    # Default seed 0 for reproducible demos.
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
                    seed_textbox.value = "0"

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

        # Order must match VideoCoF_Controller.generate's positional parameters.
        generate_button.click(
            fn=controller.generate,
            inputs=[
                diffusion_transformer_dropdown,
                base_model_dropdown,
                lora_model_dropdown,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
                ref_image,
                enable_teacache,
                teacache_threshold,
                num_skip_start_steps,
                teacache_offload,
                cfg_skip_ratio,
                enable_riflex,
                riflex_k,
                # VideoCoF-specific inputs
                source_frames_slider,
                reasoning_frames_slider,
                repeat_rope_checkbox
            ],
            outputs=[result_image, result_video, infer_progress]
        )

    return demo, controller

if __name__ == "__main__":
    from videox_fun.ui.controller import flow_scheduler_dict

    # --- Model / runtime settings -----------------------------------------
    GPU_memory_mode = "sequential_cpu_offload"
    compile_dit = False
    weight_dtype = torch.bfloat16
    config_path = "config/wan2.1/wan_civitai.yaml"

    # --- Server settings ---------------------------------------------------
    server_name, server_port = "0.0.0.0", 7860

    demo, controller = ui(
        GPU_memory_mode,
        flow_scheduler_dict,
        config_path,
        compile_dit,
        weight_dtype,
    )

    launch_options = dict(
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,
        share=False,
    )
    demo.queue(status_update_rate=1).launch(**launch_options)

    # With prevent_thread_lock=True, launch() does not block, so park the main
    # thread here to keep the process (and the server) alive.
    while True:
        time.sleep(5)