# ViBT / app.py — Hugging Face Space application (commit d669681, verified).
import functools
from dataclasses import dataclass
import random
import gradio as gr
import spaces
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video, load_video
from vibt.wan import load_vibt_weight, encode_video
from vibt.scheduler import ViBTScheduler
import tempfile
import os
import cv2
def get_fps(path, default_fps=16.0):
    """Return the frame rate of the video at *path*.

    Args:
        path: Filesystem path to a video readable by OpenCV.
        default_fps: Fallback when OpenCV cannot determine the frame rate
            (file fails to open, or the container reports 0/NaN). Defaults
            to 16, the frame rate the Wan pipeline exports at elsewhere in
            this app — TODO confirm against model card.

    Returns:
        The FPS as a float; ``default_fps`` when unreadable.
    """
    cap = cv2.VideoCapture(path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Always release the capture handle, even if the property read raises.
        cap.release()
    # OpenCV reports 0.0 when timing metadata is missing; `fps != fps` screens
    # out NaN. Exporting with fps<=0 would fail downstream in export_to_video.
    if not fps or fps != fps or fps <= 0:
        return default_fps
    return fps
# Base Wan2.1 1.3B text-to-video model that ViBT adapts for stylization.
base_model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
# NOTE: float32 VAE override kept for reference; the pipeline's bundled VAE
# is used instead.
# vae = AutoencoderKLWan.from_pretrained(
#     base_model_id, subfolder="vae", torch_dtype=torch.float32
# )
pipe = WanPipeline.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Overlay the ViBT video-stylization weights onto the base transformer.
load_vibt_weight(
    pipe.transformer,
    "Yuanshi/ViBT",
    "video/video_stylization.safetensors",
)
# Wrap the stock scheduler with the ViBT bridge scheduler.
pipe.scheduler = ViBTScheduler.from_scheduler(pipe.scheduler)
@dataclass(frozen=True)
class SliderConfig:
    """Immutable bundle of constructor arguments for a gr.Slider."""
    label: str  # Display label shown above the slider.
    minimum: float  # Lowest selectable value.
    maximum: float  # Highest selectable value.
    step: float  # Increment between selectable values.
    value: float  # Default value shown on load.
    info: str  # Helper text rendered under the slider.
@dataclass(frozen=True)
class PresetConfig:
    """Inference settings bound to one quick-generate preset button."""
    shift_gamma: float  # ViBT scheduler shift parameter.
    steps: int  # Number of denoising steps.
    guidance_scale: float  # Classifier-free guidance (CFG) scale.
# Slider definitions consumed by _create_slider in the "Advanced Settings" tab.
GAMMA_SLIDER = SliderConfig(
    label="Shift Gamma",
    minimum=1.0,
    maximum=10.0,
    step=0.5,
    value=5.0,
    info="Scheduler adjustment parameter.",
)
STEP_SLIDER = SliderConfig(
    label="Inference Steps",
    minimum=1,
    maximum=28,
    step=1,
    value=10,
    info="More steps improve quality but take longer.",
)
GUIDANCE_SLIDER = SliderConfig(
    label="Guidance Scale (CFG)",
    minimum=1.0,
    maximum=5.0,
    step=0.5,
    value=2,
    info="Controls adherence to the text prompt.",
)
# Preset style prompts offered in the dropdown (custom values also allowed).
STYLE_CHOICES = [
    "Make it Illustration style.",
    "Make it a drawing by Van Gogh.",
    "Make it a pencil sketch style.",
    "Make it watercolor drawing style.",
    "Make it a Pixel Art.",
    "Make it a Japanese anime style, cel shading.",
    "Make it the style of Neon Light Art.",
    "Make it papercut style.",
    "Make it a blueprint.",
    "Make it Comic Book Style.",
    "Render the subject as a classical sculpture carved from a single block of pristine white marble.",
]
# (video path, style prompt) pairs shown in the gr.Examples gallery.
EXAMPLE_INPUTS = [
    ["assets/video_00000000.mp4", STYLE_CHOICES[0]],
    ["assets/video_00000007.mp4", STYLE_CHOICES[1]],
    ["assets/video_00000019.mp4", STYLE_CHOICES[2]],
    ["assets/video_00000071.mp4", STYLE_CHOICES[3]],
]
# Quick-generate presets: identical gamma/CFG, trading step count for speed.
PRESET_MODES = {
    "Fast": PresetConfig(shift_gamma=5.0, steps=6, guidance_scale=2),
    "Balanced": PresetConfig(shift_gamma=5.0, steps=10, guidance_scale=2),
    "Quality": PresetConfig(shift_gamma=5.0, steps=20, guidance_scale=2),
}
def _create_slider(config: SliderConfig) -> gr.Slider:
    """Build a gr.Slider whose settings mirror *config* field-for-field."""
    slider_kwargs = {
        "label": config.label,
        "minimum": config.minimum,
        "maximum": config.maximum,
        "step": config.step,
        "value": config.value,
        "info": config.info,
    }
    return gr.Slider(**slider_kwargs)
@spaces.GPU(duration=120)
def run_stylization(
    input_video_path,
    prompt,
    shift_gamma,
    steps,
    guidance_scale,
    seed,
    randomize_seed,
):
    """Stylize a video with the ViBT-adapted Wan pipeline.

    Args:
        input_video_path: Path to the uploaded source video; falsy skips.
        prompt: Text instruction describing the target style.
        shift_gamma: ViBT scheduler shift parameter.
        steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: User-supplied seed; ignored when ``randomize_seed`` is truthy.
        randomize_seed: When truthy, draw a fresh random seed for this run.

    Returns:
        Path to the exported mp4, or None when no input was given or the
        video contains no decodable frames.
    """
    # Fixed processing window — presumably what this Wan 1.3B checkpoint
    # was trained on (832x480, 81 frames); TODO confirm against model card.
    target_size = (832, 480)
    num_frames = 81

    if not input_video_path:
        return None

    resolved_seed = _resolve_seed(seed, randomize_seed)
    print("========== Inference Start ==========")
    print(f"Video Path: {input_video_path}")
    print(f"Prompt: {prompt}")
    print(
        "Params: "
        f"Gamma={shift_gamma}, "
        f"Steps={steps}, "
        f"CFG={guidance_scale}, "
        f"Seed={resolved_seed}"
    )

    frames = load_video(input_video_path)
    if not frames:
        # An unreadable/empty video would otherwise IndexError on frames[-1].
        print("No decodable frames in input video.")
        return None
    # Truncate BEFORE resizing so frames past the model window are never
    # resized needlessly; pad short clips by repeating the final frame.
    frames = [frame.resize(target_size) for frame in frames[:num_frames]]
    if len(frames) < num_frames:
        frames += [frames[-1]] * (num_frames - len(frames))

    source_fps = get_fps(input_video_path)
    # OpenCV reports 0 (or NaN) for videos with missing timing metadata;
    # fall back to 16 fps rather than crash export_to_video.
    if not source_fps or source_fps != source_fps:
        source_fps = 16

    source_latents = encode_video(pipe, frames)
    pipe.scheduler.set_parameters(
        noise_scale=1.0, shift_gamma=shift_gamma, seed=resolved_seed
    )
    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        latents=source_latents,
    ).frames[0]

    # Write the result into a fresh temp directory with a random file name.
    tmp_dir = tempfile.mkdtemp()
    out_path = os.path.join(tmp_dir, f"{random.randint(0, 2**31 - 1)}.mp4")
    export_to_video(output, out_path, fps=source_fps)
    print(out_path)
    return out_path
def _resolve_seed(seed_value, randomize):
"""Return an integer seed, generating a random one when requested or missing."""
if randomize or seed_value in (None, ""):
return random.randint(0, 2**31 - 1)
return int(seed_value)
def run_with_preset(input_video_path, prompt, seed, randomize_seed, preset_key):
    """Wrap stylization with predefined presets for quick generation."""
    cfg = PRESET_MODES[preset_key]
    gamma, n_steps, cfg_scale = cfg.shift_gamma, cfg.steps, cfg.guidance_scale
    return run_stylization(
        input_video_path=input_video_path,
        prompt=prompt,
        shift_gamma=gamma,
        steps=n_steps,
        guidance_scale=cfg_scale,
        seed=seed,
        randomize_seed=randomize_seed,
    )
def _bind_preset_button(button, preset_key, inputs, output, extra_kwargs=None):
    """Wire *button* to run_with_preset with *preset_key* (and any extra
    keyword overrides) pre-bound via functools.partial."""
    bound_kwargs = {} if extra_kwargs is None else extra_kwargs
    handler = functools.partial(run_with_preset, preset_key=preset_key, **bound_kwargs)
    button.click(fn=handler, inputs=inputs, outputs=[output])
def build_demo() -> gr.Blocks:
    """Create the Gradio interface for video stylization.

    Layout: source/result video pair on top, then a control column with two
    tabs (one-click presets vs. full manual parameters) beside a column of
    example inputs.

    Returns:
        The assembled (un-launched) gr.Blocks application.
    """
    with gr.Blocks() as demo:
        with gr.Column(elem_id="col-container"):
            # Constrain the overall page width.
            gr.HTML(
                """
                <style>
                #col-container { max-width: 1200px; margin: 0 auto; }
                </style>
                """
            )
            gr.Markdown(
                """
                # 🎥 ViBT: Vision Bridge Transformer at Scale
                <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
                <a href="https://yuanshi9815.github.io/ViBT_homepage"><img src="https://img.shields.io/badge/Web-Project Page-1d72b8.svg" alt="Project Page"></a>
                <a href="https://arxiv.org/abs/2511.23199"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
                <a href="https://huggingface.co/Yuanshi/ViBT"><img src="https://img.shields.io/badge/🤗Huggingface-Model-ffbd45.svg" alt="HuggingFace"></a>
                <a href="https://github.com/Yuanshi9815/ViBT"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
                </div>
                """
            )
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Source Video", sources=["upload"])
                with gr.Column():
                    output_video = gr.Video(label="Stylized Result", interactive=False)
            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Dropdown(
                        label="Style Instruction",
                        choices=STYLE_CHOICES,
                        value=STYLE_CHOICES[0],
                        allow_custom_value=True,
                    )
                    with gr.Tabs():
                        with gr.Tab("Quick Generate"):
                            # One-click presets: each button runs with a fixed
                            # PresetConfig and a fresh random seed.
                            with gr.Row():
                                fast_btn = gr.Button(
                                    "⚡ Fast", variant="primary"
                                )
                                balanced_btn = gr.Button(
                                    "🎯 Balanced", variant="primary"
                                )
                                quality_btn = gr.Button(
                                    "🌟 High Quality", variant="primary"
                                )
                            _bind_preset_button(
                                button=fast_btn,
                                preset_key="Fast",
                                inputs=[
                                    input_video,
                                    prompt,
                                ],
                                output=output_video,
                                extra_kwargs={"seed": None, "randomize_seed": True},
                            )
                            _bind_preset_button(
                                button=balanced_btn,
                                preset_key="Balanced",
                                inputs=[
                                    input_video,
                                    prompt,
                                ],
                                output=output_video,
                                extra_kwargs={"seed": None, "randomize_seed": True},
                            )
                            _bind_preset_button(
                                button=quality_btn,
                                preset_key="Quality",
                                inputs=[
                                    input_video,
                                    prompt,
                                ],
                                output=output_video,
                                extra_kwargs={"seed": None, "randomize_seed": True},
                            )
                        with gr.Tab("Advanced Settings"):
                            with gr.Row():
                                shift_gamma = _create_slider(GAMMA_SLIDER)
                                guidance_scale = _create_slider(GUIDANCE_SLIDER)
                            with gr.Row():
                                num_steps = _create_slider(STEP_SLIDER)
                            randomize_seed_adv = gr.Checkbox(
                                label="Randomize Seed",
                                value=True,
                                info="Checked = new random seed each run. Uncheck to provide your own seed.",
                            )
                            seed_adv = gr.Number(
                                label="Seed (used when Randomize is off)",
                                value=42,
                                precision=0,
                            )
                            run_btn = gr.Button("Generate", variant="primary")
                            run_btn.click(
                                fn=run_stylization,
                                inputs=[
                                    input_video,
                                    prompt,
                                    shift_gamma,
                                    num_steps,
                                    guidance_scale,
                                    seed_adv,
                                    randomize_seed_adv,
                                ],
                                outputs=[output_video],
                            )
                with gr.Column(scale=1):
                    gr.Examples(
                        examples=EXAMPLE_INPUTS,
                        inputs=[input_video, prompt],
                        label="Example inputs",
                    )
    return demo
# Build the UI at import time so Spaces / `gradio app.py` can find `demo`.
demo = build_demo()
if __name__ == "__main__":
    demo.launch()