|
|
import functools
import math
import os
import random
import tempfile
from dataclasses import dataclass

import cv2
import gradio as gr
import spaces
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video, load_video

from vibt.scheduler import ViBTScheduler
from vibt.wan import load_vibt_weight, encode_video
|
|
|
|
|
|
|
|
def get_fps(path, default_fps=16.0):
    """Return the frame rate OpenCV reports for the video at *path*.

    Args:
        path: Filesystem path of the video to inspect.
        default_fps: Value returned when OpenCV reports a frame rate that is
            not a finite positive number (for example when the file cannot
            be opened or carries no FPS metadata).
            NOTE(review): 16.0 is an assumed fallback -- confirm it suits the
            videos this app exports.

    Returns:
        The reported frames-per-second value, or ``default_fps`` when the
        reported value is unusable.
    """
    cap = cv2.VideoCapture(path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Always free the capture handle, even if the property query raises.
        cap.release()

    # A zero, negative, NaN or infinite rate cannot be used to write a video.
    if not math.isfinite(fps) or fps <= 0:
        return default_fps
    return fps
|
|
|
|
|
|
|
|
# Identifier of the base text-to-video checkpoint handed to `from_pretrained`.
base_model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# Build the pipeline once at import time so every request reuses the same
# weights; it is loaded in bfloat16 and moved to the CUDA device.
pipe = WanPipeline.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Apply the ViBT video-stylization weights to the pipeline's transformer.
# NOTE(review): the two strings are presumably a Hub repo id and a file path
# inside that repo -- confirm against `vibt.wan.load_vibt_weight`.
load_vibt_weight(
    pipe.transformer,
    "Yuanshi/ViBT",
    "video/video_stylization.safetensors",
)

# Replace the stock scheduler with a ViBT scheduler derived from it.
pipe.scheduler = ViBTScheduler.from_scheduler(pipe.scheduler)
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class SliderConfig:
    """Immutable description of one slider shown in the UI.

    Attributes:
        label: Text displayed next to the slider.
        minimum: Lowest selectable value.
        maximum: Highest selectable value.
        step: Increment between selectable values.
        value: Initial value of the slider.
        info: Short help text displayed with the slider.
    """

    label: str
    minimum: float
    maximum: float
    step: float
    value: float
    info: str
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PresetConfig:
    """Immutable bundle of generation parameters behind a preset button.

    Attributes:
        shift_gamma: Shift-gamma value forwarded to the stylization call.
        steps: Number of inference steps.
        guidance_scale: Guidance (CFG) scale.
    """

    shift_gamma: float
    steps: int
    guidance_scale: float
|
|
|
|
|
|
|
|
# Slider definitions for the "Advanced Settings" tab. Keeping them as data
# lets the UI builder create every slider through the same helper.
GAMMA_SLIDER = SliderConfig(
    label="Shift Gamma",
    minimum=1.0,
    maximum=10.0,
    step=0.5,
    value=5.0,
    info="Scheduler adjustment parameter.",
)

STEP_SLIDER = SliderConfig(
    label="Inference Steps",
    minimum=1,
    maximum=28,
    step=1,
    value=10,
    info="More steps improve quality but take longer.",
)

GUIDANCE_SLIDER = SliderConfig(
    label="Guidance Scale (CFG)",
    minimum=1.0,
    maximum=5.0,
    step=0.5,
    value=2,
    info="Controls adherence to the text prompt.",
)
|
|
|
|
|
|
|
|
# Style instructions offered in the prompt dropdown. The first entry is the
# dropdown's initial value, and the first four are paired with example videos.
STYLE_CHOICES = [
    "Make it Illustration style.",
    "Make it a drawing by Van Gogh.",
    "Make it a pencil sketch style.",
    "Make it watercolor drawing style.",
    "Make it a Pixel Art.",
    "Make it a Japanese anime style, cel shading.",
    "Make it the style of Neon Light Art.",
    "Make it papercut style.",
    "Make it a blueprint.",
    "Make it Comic Book Style.",
    "Render the subject as a classical sculpture carved from a single block of pristine white marble.",
]
|
|
|
|
|
|
|
|
# Rows for the examples widget: each row is [video path, style instruction],
# in the same order as the components the examples are bound to.
EXAMPLE_INPUTS = [
    ["assets/video_00000000.mp4", STYLE_CHOICES[0]],
    ["assets/video_00000007.mp4", STYLE_CHOICES[1]],
    ["assets/video_00000019.mp4", STYLE_CHOICES[2]],
    ["assets/video_00000071.mp4", STYLE_CHOICES[3]],
]
|
|
|
|
|
|
|
|
# Parameter bundles behind the "Quick Generate" buttons, keyed by preset name.
# The presets differ only in the number of inference steps.
PRESET_MODES = {
    "Fast": PresetConfig(shift_gamma=5.0, steps=6, guidance_scale=2),
    "Balanced": PresetConfig(shift_gamma=5.0, steps=10, guidance_scale=2),
    "Quality": PresetConfig(shift_gamma=5.0, steps=20, guidance_scale=2),
}
|
|
|
|
|
|
|
|
def _create_slider(config: SliderConfig) -> gr.Slider:
    """Build a Gradio slider from a ``SliderConfig``.

    Args:
        config: Label, range, step, initial value and help text to use.

    Returns:
        A ``gr.Slider`` whose arguments are copied one-to-one from *config*.
    """
    return gr.Slider(
        label=config.label,
        minimum=config.minimum,
        maximum=config.maximum,
        step=config.step,
        value=config.value,
        info=config.info,
    )
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def run_stylization(
    input_video_path,
    prompt,
    shift_gamma,
    steps,
    guidance_scale,
    seed,
    randomize_seed,
):
    """Stylize a source video according to a text instruction.

    The source clip is cut or padded to a fixed number of frames, resized to
    a fixed resolution, encoded with ``encode_video`` and passed to the
    module-level pipeline as its initial latents. The result is written to an
    MP4 file in a fresh temporary directory.

    Args:
        input_video_path: Path of the uploaded video; a falsy value means no
            video was provided.
        prompt: Style instruction passed to the pipeline as its prompt.
        shift_gamma: Value forwarded to the scheduler's ``set_parameters``.
        steps: Number of inference steps.
        guidance_scale: Guidance (CFG) scale passed to the pipeline.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, a random seed replaces ``seed``.

    Returns:
        Path of the generated MP4 file, or ``None`` when no input video was
        supplied.

    Raises:
        gr.Error: If no frames could be read from the input video.
    """
    if not input_video_path:
        return None

    # Clip geometry fed to the model: (width, height) and number of frames.
    frame_size = (832, 480)
    frame_count = 81

    resolved_seed = _resolve_seed(seed, randomize_seed)
    # The value may arrive from the UI as a float; the step count must be
    # an integer.
    steps = int(steps)

    print("========== Inference Start ==========")
    print(f"Video Path: {input_video_path}")
    print(f"Prompt: {prompt}")
    print(
        "Params: "
        f"Gamma={shift_gamma}, "
        f"Steps={steps}, "
        f"CFG={guidance_scale}, "
        f"Seed={resolved_seed}"
    )

    loaded_frames = load_video(input_video_path)
    if not loaded_frames:
        # Without at least one frame there is nothing to pad with below.
        raise gr.Error("No frames could be read from the input video.")

    # Truncate before resizing so frames past the limit are never processed.
    source_video = [frame.resize(frame_size) for frame in loaded_frames[:frame_count]]
    if len(source_video) < frame_count:
        # Short clips are padded by repeating their final frame.
        source_video += [source_video[-1]] * (frame_count - len(source_video))
    source_fps = get_fps(input_video_path)

    source_latents = encode_video(pipe, source_video)

    pipe.scheduler.set_parameters(
        noise_scale=1.0, shift_gamma=shift_gamma, seed=resolved_seed
    )

    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        latents=source_latents,
    ).frames[0]

    tmp_dir = tempfile.mkdtemp()
    out_path = os.path.join(tmp_dir, f"{random.randint(0, 2**31 - 1)}.mp4")
    # Export at the source frame rate so the result plays at the same speed.
    export_to_video(output, out_path, fps=source_fps)
    print(out_path)
    return out_path
|
|
|
|
|
|
|
|
def _resolve_seed(seed_value, randomize): |
|
|
"""Return an integer seed, generating a random one when requested or missing.""" |
|
|
if randomize or seed_value in (None, ""): |
|
|
return random.randint(0, 2**31 - 1) |
|
|
return int(seed_value) |
|
|
|
|
|
|
|
|
def run_with_preset(input_video_path, prompt, seed, randomize_seed, preset_key):
    """Run stylization with the parameters of a named preset.

    Args:
        input_video_path: Path of the source video.
        prompt: Style instruction.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, a random seed is drawn instead.
        preset_key: Key into ``PRESET_MODES`` selecting the parameters.

    Returns:
        Whatever ``run_stylization`` returns for these arguments.

    Raises:
        KeyError: If ``preset_key`` is not a key of ``PRESET_MODES``.
    """
    preset = PRESET_MODES[preset_key]
    return run_stylization(
        input_video_path=input_video_path,
        prompt=prompt,
        shift_gamma=preset.shift_gamma,
        steps=preset.steps,
        guidance_scale=preset.guidance_scale,
        seed=seed,
        randomize_seed=randomize_seed,
    )
|
|
|
|
|
|
|
|
def _bind_preset_button(button, preset_key, inputs, output, extra_kwargs=None):
    """Attach a preset-driven stylization run to a button's click event.

    Args:
        button: Component whose ``click`` event is being bound.
        preset_key: Preset name fixed for this button.
        inputs: Components whose values are passed to the handler.
        output: Component that receives the handler's result.
        extra_kwargs: Optional keyword arguments to fix on the handler in
            addition to ``preset_key``; ``None`` means none.
    """
    extra_kwargs = extra_kwargs or {}
    # `partial` fixes the preset and extra arguments now, so the click event
    # only has to supply the values of `inputs`.
    button.click(
        fn=functools.partial(run_with_preset, preset_key=preset_key, **extra_kwargs),
        inputs=inputs,
        outputs=[output],
    )
|
|
|
|
|
|
|
|
def build_demo() -> gr.Blocks:
    """Create the Gradio interface for video stylization.

    The page shows the source and result videos side by side. Below them sit
    a style-instruction dropdown, a "Quick Generate" tab with three preset
    buttons, an "Advanced Settings" tab with explicit parameters, and a set
    of example inputs.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        with gr.Column(elem_id="col-container"):
            gr.HTML(
                """
                <style>
                #col-container { max-width: 1200px; margin: 0 auto; }
                </style>
                """
            )
            gr.Markdown(
                """
                # 🎥 ViBT: Vision Bridge Transformer at Scale
                <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
                <a href="https://yuanshi9815.github.io/ViBT_homepage"><img src="https://img.shields.io/badge/Web-Project Page-1d72b8.svg" alt="Project Page"></a>
                <a href="https://arxiv.org/abs/2511.23199"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
                <a href="https://huggingface.co/Yuanshi/ViBT"><img src="https://img.shields.io/badge/🤗Huggingface-Model-ffbd45.svg" alt="HuggingFace"></a>
                <a href="https://github.com/Yuanshi9815/ViBT"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
                </div>
                """
            )

            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Source Video", sources=["upload"])
                with gr.Column():
                    output_video = gr.Video(label="Stylized Result", interactive=False)

            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Dropdown(
                        label="Style Instruction",
                        choices=STYLE_CHOICES,
                        value=STYLE_CHOICES[0],
                        allow_custom_value=True,
                    )
                    with gr.Tabs():
                        with gr.Tab("Quick Generate"):
                            with gr.Row():
                                fast_btn = gr.Button("⚡ Fast", variant="primary")
                                balanced_btn = gr.Button(
                                    "🎯 Balanced", variant="primary"
                                )
                                quality_btn = gr.Button(
                                    "🌟 High Quality", variant="primary"
                                )

                            # Every preset button takes the same inputs and
                            # always draws a fresh random seed; only the
                            # preset name differs.
                            preset_buttons = (
                                (fast_btn, "Fast"),
                                (balanced_btn, "Balanced"),
                                (quality_btn, "Quality"),
                            )
                            for preset_button, preset_key in preset_buttons:
                                _bind_preset_button(
                                    button=preset_button,
                                    preset_key=preset_key,
                                    inputs=[
                                        input_video,
                                        prompt,
                                    ],
                                    output=output_video,
                                    extra_kwargs={
                                        "seed": None,
                                        "randomize_seed": True,
                                    },
                                )

                        with gr.Tab("Advanced Settings"):
                            with gr.Row():
                                shift_gamma = _create_slider(GAMMA_SLIDER)
                                guidance_scale = _create_slider(GUIDANCE_SLIDER)

                            with gr.Row():
                                num_steps = _create_slider(STEP_SLIDER)
                                randomize_seed_adv = gr.Checkbox(
                                    label="Randomize Seed",
                                    value=True,
                                    info="Checked = new random seed each run. Uncheck to provide your own seed.",
                                )

                            seed_adv = gr.Number(
                                label="Seed (used when Randomize is off)",
                                value=42,
                                precision=0,
                            )

                            run_btn = gr.Button("Generate", variant="primary")

                            # Input order must match the positional
                            # parameters of `run_stylization`.
                            run_btn.click(
                                fn=run_stylization,
                                inputs=[
                                    input_video,
                                    prompt,
                                    shift_gamma,
                                    num_steps,
                                    guidance_scale,
                                    seed_adv,
                                    randomize_seed_adv,
                                ],
                                outputs=[output_video],
                            )

                with gr.Column(scale=1):
                    gr.Examples(
                        examples=EXAMPLE_INPUTS,
                        inputs=[input_video, prompt],
                        label="Example inputs",
                    )

    return demo
|
|
|
|
|
|
|
|
# Built at import time so the app object exists as the module-level `demo`.
demo = build_demo()

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|