"""Gradio demo for ViBT video stylization on top of the Wan2.1 T2V pipeline.

Loads the base Wan2.1 1.3B text-to-video pipeline once at import time,
patches it with ViBT stylization weights and scheduler, and exposes a
Blocks UI with quick presets and advanced controls.
"""

import functools
import os
import random
import tempfile
from dataclasses import dataclass

import cv2
import gradio as gr
import spaces
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video, load_video

from vibt.scheduler import ViBTScheduler
from vibt.wan import load_vibt_weight, encode_video


def get_fps(path):
    """Return the frame rate of the video at ``path`` via OpenCV.

    NOTE: OpenCV returns 0.0 when it cannot read the container header;
    callers must handle that case (see ``run_stylization``).
    """
    cap = cv2.VideoCapture(path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps


# ---------------------------------------------------------------------------
# One-time model setup (runs at import; requires a CUDA device).
# ---------------------------------------------------------------------------
base_model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
# vae = AutoencoderKLWan.from_pretrained(
#     base_model_id, subfolder="vae", torch_dtype=torch.float32
# )
pipe = WanPipeline.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
pipe.to("cuda")
load_vibt_weight(
    pipe.transformer,
    "Yuanshi/ViBT",
    "video/video_stylization.safetensors",
)
pipe.scheduler = ViBTScheduler.from_scheduler(pipe.scheduler)


@dataclass(frozen=True)
class SliderConfig:
    """Declarative description of a ``gr.Slider`` (see ``_create_slider``)."""

    label: str
    minimum: float
    maximum: float
    step: float
    value: float
    info: str


@dataclass(frozen=True)
class PresetConfig:
    """Bundle of generation parameters used by the quick-generate buttons."""

    shift_gamma: float
    steps: int
    guidance_scale: float


GAMMA_SLIDER = SliderConfig(
    label="Shift Gamma",
    minimum=1.0,
    maximum=10.0,
    step=0.5,
    value=5.0,
    info="Scheduler adjustment parameter.",
)
STEP_SLIDER = SliderConfig(
    label="Inference Steps",
    minimum=1,
    maximum=28,
    step=1,
    value=10,
    info="More steps improve quality but take longer.",
)
GUIDANCE_SLIDER = SliderConfig(
    label="Guidance Scale (CFG)",
    minimum=1.0,
    maximum=5.0,
    step=0.5,
    value=2,
    info="Controls adherence to the text prompt.",
)

STYLE_CHOICES = [
    "Make it Illustration style.",
    "Make it a drawing by Van Gogh.",
    "Make it a pencil sketch style.",
    "Make it watercolor drawing style.",
    "Make it a Pixel Art.",
    "Make it a Japanese anime style, cel shading.",
    "Make it the style of Neon Light Art.",
    "Make it papercut style.",
    "Make it a blueprint.",
    "Make it Comic Book Style.",
    "Render the subject as a classical sculpture carved from a single block of pristine white marble.",
]

EXAMPLE_INPUTS = [
    ["assets/video_00000000.mp4", STYLE_CHOICES[0]],
    ["assets/video_00000007.mp4", STYLE_CHOICES[1]],
    ["assets/video_00000019.mp4", STYLE_CHOICES[2]],
    ["assets/video_00000071.mp4", STYLE_CHOICES[3]],
]

PRESET_MODES = {
    "Fast": PresetConfig(shift_gamma=5.0, steps=6, guidance_scale=2),
    "Balanced": PresetConfig(shift_gamma=5.0, steps=10, guidance_scale=2),
    "Quality": PresetConfig(shift_gamma=5.0, steps=20, guidance_scale=2),
}


def _create_slider(config: SliderConfig) -> gr.Slider:
    """Helper to keep slider creation consistent."""
    return gr.Slider(
        label=config.label,
        minimum=config.minimum,
        maximum=config.maximum,
        step=config.step,
        value=config.value,
        info=config.info,
    )


@spaces.GPU(duration=120)
def run_stylization(
    input_video_path,
    prompt,
    shift_gamma,
    steps,
    guidance_scale,
    seed,
    randomize_seed,
):
    """Stylize ``input_video_path`` with the ViBT pipeline and return an mp4 path.

    The source video is resized to 832x480, truncated/padded to exactly 81
    frames (the model's fixed clip length — TODO confirm against the Wan2.1
    config), encoded to latents, and denoised under ``prompt``.

    Returns ``None`` when no video was provided; otherwise the path of the
    exported result, written at the source video's frame rate.
    """
    if not input_video_path:
        return None
    resolved_seed = _resolve_seed(seed, randomize_seed)
    print("========== Inference Start ==========")
    print(f"Video Path: {input_video_path}")
    print(f"Prompt: {prompt}")
    print(
        "Params: "
        f"Gamma={shift_gamma}, "
        f"Steps={steps}, "
        f"CFG={guidance_scale}, "
        f"Seed={resolved_seed}"
    )

    source_video = load_video(input_video_path)
    if not source_video:
        # Without this guard, padding below would crash with IndexError
        # on source_video[-1] for an unreadable/empty upload.
        raise gr.Error("Could not read any frames from the uploaded video.")
    # Fixed working resolution and clip length for the model.
    source_video = [each.resize((832, 480)) for each in source_video][:81]
    if len(source_video) < 81:
        # Pad short clips by repeating the last frame.
        source_video += [source_video[-1]] * (81 - len(source_video))
    source_fps = get_fps(input_video_path)
    if not source_fps or source_fps <= 0:
        # OpenCV reports 0 when the header is unreadable; exporting at
        # fps=0 would produce a broken file, so fall back to Wan's native 16.
        source_fps = 16

    source_latents = encode_video(pipe, source_video)
    pipe.scheduler.set_parameters(
        noise_scale=1.0, shift_gamma=shift_gamma, seed=resolved_seed
    )
    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        latents=source_latents,
    ).frames[0]

    tmp_dir = tempfile.mkdtemp()
    out_path = os.path.join(tmp_dir, f"{random.randint(0, 2**31 - 1)}.mp4")
    export_to_video(output, out_path, fps=source_fps)
    print(out_path)
    return out_path


def _resolve_seed(seed_value, randomize):
    """Return an integer seed, generating a random one when requested or missing."""
    if randomize or seed_value in (None, ""):
        return random.randint(0, 2**31 - 1)
    return int(seed_value)


def run_with_preset(input_video_path, prompt, seed, randomize_seed, preset_key):
    """Wrap stylization with predefined presets for quick generation."""
    preset = PRESET_MODES[preset_key]
    return run_stylization(
        input_video_path=input_video_path,
        prompt=prompt,
        shift_gamma=preset.shift_gamma,
        steps=preset.steps,
        guidance_scale=preset.guidance_scale,
        seed=seed,
        randomize_seed=randomize_seed,
    )


def _bind_preset_button(button, preset_key, inputs, output, extra_kwargs=None):
    """Wire ``button`` to ``run_with_preset`` with ``preset_key`` pre-bound."""
    extra_kwargs = extra_kwargs or {}
    button.click(
        fn=functools.partial(run_with_preset, preset_key=preset_key, **extra_kwargs),
        inputs=inputs,
        outputs=[output],
    )


def build_demo() -> gr.Blocks:
    """Create the Gradio interface for video stylization."""
    with gr.Blocks() as demo:
        with gr.Column(elem_id="col-container"):
            gr.HTML(
                """
                """
            )
            gr.Markdown(
                """
                # 🎥 ViBT: Vision Bridge Transformer at Scale

                Project Page arXiv HuggingFace GitHub
                """
            )
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Source Video", sources=["upload"])
                with gr.Column():
                    output_video = gr.Video(label="Stylized Result", interactive=False)
            with gr.Row():
                with gr.Column(scale=1) as control_col:
                    prompt = gr.Dropdown(
                        label="Style Instruction",
                        choices=STYLE_CHOICES,
                        value=STYLE_CHOICES[0],
                        allow_custom_value=True,
                    )
                    with gr.Tabs():
                        with gr.Tab("Quick Generate"):
                            with gr.Row():
                                fast_btn = gr.Button("⚡ Fast", variant="primary")
                                balanced_btn = gr.Button("🎯 Balanced", variant="primary")
                                quality_btn = gr.Button("🌟 High Quality", variant="primary")
                            # All preset buttons randomize the seed on each run.
                            _bind_preset_button(
                                button=fast_btn,
                                preset_key="Fast",
                                inputs=[input_video, prompt],
                                output=output_video,
                                extra_kwargs={"seed": None, "randomize_seed": True},
                            )
                            _bind_preset_button(
                                button=balanced_btn,
                                preset_key="Balanced",
                                inputs=[input_video, prompt],
                                output=output_video,
                                extra_kwargs={"seed": None, "randomize_seed": True},
                            )
                            _bind_preset_button(
                                button=quality_btn,
                                preset_key="Quality",
                                inputs=[input_video, prompt],
                                output=output_video,
                                extra_kwargs={"seed": None, "randomize_seed": True},
                            )
                        with gr.Tab("Advanced Settings"):
                            with gr.Row():
                                shift_gamma = _create_slider(GAMMA_SLIDER)
                                guidance_scale = _create_slider(GUIDANCE_SLIDER)
                            with gr.Row():
                                num_steps = _create_slider(STEP_SLIDER)
                                randomize_seed_adv = gr.Checkbox(
                                    label="Randomize Seed",
                                    value=True,
                                    info="Checked = new random seed each run. Uncheck to provide your own seed.",
                                )
                                seed_adv = gr.Number(
                                    label="Seed (used when Randomize is off)",
                                    value=42,
                                    precision=0,
                                )
                            run_btn = gr.Button("Generate", variant="primary")
                            run_btn.click(
                                fn=run_stylization,
                                inputs=[
                                    input_video,
                                    prompt,
                                    shift_gamma,
                                    num_steps,
                                    guidance_scale,
                                    seed_adv,
                                    randomize_seed_adv,
                                ],
                                outputs=[output_video],
                            )
                with gr.Column(scale=1):
                    gr.Examples(
                        examples=EXAMPLE_INPUTS,
                        inputs=[input_video, prompt],
                        label="Example inputs",
                    )
    return demo


demo = build_demo()

if __name__ == "__main__":
    demo.launch()