# ViBT / app.py — Hugging Face Space application.
# (Hugging Face page header: Yuanshi · commit bda8013 "add AutoencoderKLWan" ·
#  raw / history / blame · 11.2 kB)
import functools
from dataclasses import dataclass
import random
import gradio as gr
import spaces
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video, load_video
from vibt.wan import load_vibt_weight, encode_video
from vibt.scheduler import ViBTScheduler
import tempfile
import os
import cv2
def get_fps(path):
    """Return the frames-per-second reported by the video container at *path*.

    Uses OpenCV's ``VideoCapture``; the capture is always released, even if
    reading the property raises. Returns ``0.0`` when the file cannot be
    opened or the container does not report an FPS — callers should treat a
    non-positive value as "unknown".
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            return 0.0
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the underlying decoder handle on every path (was leaked
        # on exception in the original).
        cap.release()
# ---- Model setup (runs once at import time; order matters) ----
base_model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
# Keep the Wan VAE in float32 while the rest of the pipeline runs in bfloat16.
vae = AutoencoderKLWan.from_pretrained(
    base_model_id, subfolder="vae", torch_dtype=torch.float32
)
pipe = WanPipeline.from_pretrained(base_model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Overlay the ViBT stylization weights onto the Wan transformer in place.
load_vibt_weight(
    pipe.transformer,
    "Yuanshi/ViBT",
    "video/video_stylization.safetensors",
)
# Replace the default scheduler with the ViBT scheduler wrapping it.
pipe.scheduler = ViBTScheduler.from_scheduler(pipe.scheduler)
@dataclass(frozen=True)
class SliderConfig:
    """Immutable bundle of constructor arguments for a ``gr.Slider``."""

    label: str      # Display label shown above the slider.
    minimum: float  # Lower bound of the slider range.
    maximum: float  # Upper bound of the slider range.
    step: float     # Increment between selectable values.
    value: float    # Initial/default value.
    info: str       # Help text rendered under the slider.
@dataclass(frozen=True)
class PresetConfig:
    """Immutable generation settings used by the "Quick Generate" presets."""

    shift_gamma: float     # ViBTScheduler shift parameter.
    steps: int             # Number of inference steps.
    guidance_scale: float  # Classifier-free guidance scale.
# Slider presets for the "Advanced Settings" tab (consumed by _create_slider).
GAMMA_SLIDER = SliderConfig(
    label="Shift Gamma",
    minimum=1.0,
    maximum=10.0,
    step=0.5,
    value=5.0,
    info="Scheduler adjustment parameter.",
)
STEP_SLIDER = SliderConfig(
    label="Inference Steps",
    minimum=6,
    maximum=28,
    step=1,
    value=10,
    info="More steps improve quality but take longer.",
)
GUIDANCE_SLIDER = SliderConfig(
    label="Guidance Scale (CFG)",
    minimum=1.0,
    maximum=5.0,
    step=0.5,
    value=2,
    info="Controls adherence to the text prompt.",
)
# Preset style prompts for the dropdown (custom values are also allowed).
STYLE_CHOICES = [
    "Make it Illustration style.",
    "Make it a drawing by Van Gogh.",
    "Make it a pencil sketch style.",
    "Make it watercolor drawing style.",
    "Make it a Pixel Art.",
    "Make it a Japanese anime style, cel shading.",
    "Make it the style of Neon Light Art.",
    "Make it papercut style.",
    "Make it a blueprint.",
    "Make it Comic Book Style.",
    "Render the subject as a classical sculpture carved from a single block of pristine white marble.",
]
# (video path, style prompt) pairs fed to gr.Examples.
EXAMPLE_INPUTS = [
    ["assets/video_00000000.mp4", STYLE_CHOICES[0]],
    ["assets/video_00000007.mp4", STYLE_CHOICES[1]],
    ["assets/video_00000019.mp4", STYLE_CHOICES[2]],
    ["assets/video_00000071.mp4", STYLE_CHOICES[3]],
]
# "Quick Generate" presets; they differ only in the number of inference steps.
PRESET_MODES = {
    "Fast": PresetConfig(shift_gamma=5.0, steps=6, guidance_scale=2),
    "Balanced": PresetConfig(shift_gamma=5.0, steps=10, guidance_scale=2),
    "Quality": PresetConfig(shift_gamma=5.0, steps=20, guidance_scale=2),
}
def _create_slider(config: SliderConfig) -> gr.Slider:
    """Build a ``gr.Slider`` whose settings mirror *config* field-for-field."""
    slider_kwargs = {
        "label": config.label,
        "minimum": config.minimum,
        "maximum": config.maximum,
        "step": config.step,
        "value": config.value,
        "info": config.info,
    }
    return gr.Slider(**slider_kwargs)
@spaces.GPU
def run_stylization(
    input_video_path,
    prompt,
    shift_gamma,
    steps,
    guidance_scale,
    seed,
    randomize_seed,
):
    """Stylize *input_video_path* with the ViBT/Wan pipeline and export an MP4.

    Args:
        input_video_path: Path to the uploaded source video; falsy → ``None``.
        prompt: Style instruction passed to the diffusion pipeline.
        shift_gamma: Shift parameter forwarded to the ViBT scheduler.
        steps: Number of inference steps.
        guidance_scale: Classifier-free guidance scale.
        seed: User-supplied seed (ignored when *randomize_seed* is truthy).
        randomize_seed: When truthy, draw a fresh random seed instead.

    Returns:
        Path of the exported MP4, or ``None`` when no input video was given.
    """
    if not input_video_path:
        return None
    resolved_seed = _resolve_seed(seed, randomize_seed)
    print("========== Inference Start ==========")
    print(f"Video Path: {input_video_path}")
    print(f"Prompt: {prompt}")
    print(
        "Params: "
        f"Gamma={shift_gamma}, "
        f"Steps={steps}, "
        f"CFG={guidance_scale}, "
        f"Seed={resolved_seed}"
    )
    source_video = load_video(input_video_path)
    # The pipeline consumes 832x480 frames and exactly 81 of them: truncate
    # longer videos and pad shorter ones by repeating the last frame.
    source_video = [each.resize((832, 480)) for each in source_video][:81]
    if not source_video:
        # Previously an IndexError on source_video[-1]; surface a clear
        # user-facing error instead.
        raise gr.Error("Could not decode any frames from the uploaded video.")
    if len(source_video) < 81:
        source_video += [source_video[-1]] * (81 - len(source_video))
    source_fps = get_fps(input_video_path)
    if not source_fps or source_fps <= 0:
        # Some containers report 0/NaN FPS; fall back to 16 fps so
        # export_to_video still produces a playable file (TODO confirm 16
        # matches the model's native frame rate).
        source_fps = 16
    source_latents = encode_video(pipe, source_video)
    pipe.scheduler.set_parameters(
        noise_scale=1.0, shift_gamma=shift_gamma, seed=resolved_seed
    )
    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        latents=source_latents,
    ).frames[0]
    # Fresh temp dir + random basename avoids collisions between requests.
    tmp_dir = tempfile.mkdtemp()
    out_path = os.path.join(tmp_dir, f"{random.randint(0, 2**31 - 1)}.mp4")
    export_to_video(output, out_path, fps=source_fps)
    print(out_path)
    return out_path
def _resolve_seed(seed_value, randomize):
"""Return an integer seed, generating a random one when requested or missing."""
if randomize or seed_value in (None, ""):
return random.randint(0, 2**31 - 1)
return int(seed_value)
def run_with_preset(input_video_path, prompt, seed, randomize_seed, preset_key):
    """Run stylization using the PresetConfig registered under *preset_key*."""
    cfg = PRESET_MODES[preset_key]
    return run_stylization(
        input_video_path,
        prompt,
        cfg.shift_gamma,
        cfg.steps,
        cfg.guidance_scale,
        seed,
        randomize_seed,
    )
def _bind_preset_button(button, preset_key, inputs, output, extra_kwargs=None):
    """Wire *button* to run_with_preset with *preset_key* (and extras) baked in."""
    # Defensive copy so the caller's dict is never mutated or shared.
    bound = dict(extra_kwargs) if extra_kwargs else {}
    handler = functools.partial(run_with_preset, preset_key=preset_key, **bound)
    button.click(fn=handler, inputs=inputs, outputs=[output])
def build_demo() -> gr.Blocks:
    """Create the Gradio interface for video stylization.

    Layout: header, then input/output videos side by side, then a controls
    column (style prompt plus "Quick Generate" preset buttons or "Advanced
    Settings" sliders) next to an examples column.
    """
    with gr.Blocks() as demo:
        with gr.Column(elem_id="col-container"):
            # Page-level CSS: cap the width of the main container.
            gr.HTML(
                """
                <style>
                #col-container { max-width: 1200px; margin: 0 auto; }
                </style>
                """
            )
            # Title and badge links.
            # NOTE(review): the arXiv/Demo/GitHub badges point at OminiControl
            # resources — confirm these are the intended links for ViBT.
            gr.Markdown(
                """
                # 🎥 ViBT: Vision Bridge Transformer at Scale
                <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
                <a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
                <a href="https://huggingface.co/spaces/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗OminiControl-Demo-ffbd45.svg" alt="HuggingFace"></a>
                <a href="https://github.com/Yuanshi9815/OminiControl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
                </div>
                """
            )
            # Source video on the left, stylized result on the right.
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Source Video", sources=["upload"])
                with gr.Column():
                    output_video = gr.Video(label="Stylized Result", interactive=False)
            with gr.Row():
                with gr.Column(scale=1) as control_col:
                    # Style instruction; free-text values are allowed in
                    # addition to the preset choices.
                    prompt = gr.Dropdown(
                        label="Style Instruction",
                        choices=STYLE_CHOICES,
                        value=STYLE_CHOICES[0],
                        allow_custom_value=True,
                    )
                    with gr.Tabs():
                        with gr.Tab("Quick Generate"):
                            # One button per preset; every preset run uses a
                            # randomized seed.
                            with gr.Row():
                                fast_btn = gr.Button(
                                    "⚡ Fast Generate", variant="primary"
                                )
                                balanced_btn = gr.Button(
                                    "🎯 Balanced Generate", variant="primary"
                                )
                                quality_btn = gr.Button(
                                    "🌟 Quality Generate", variant="primary"
                                )
                            _bind_preset_button(
                                button=fast_btn,
                                preset_key="Fast",
                                inputs=[
                                    input_video,
                                    prompt,
                                ],
                                output=output_video,
                                extra_kwargs={"seed": None, "randomize_seed": True},
                            )
                            _bind_preset_button(
                                button=balanced_btn,
                                preset_key="Balanced",
                                inputs=[
                                    input_video,
                                    prompt,
                                ],
                                output=output_video,
                                extra_kwargs={"seed": None, "randomize_seed": True},
                            )
                            _bind_preset_button(
                                button=quality_btn,
                                preset_key="Quality",
                                inputs=[
                                    input_video,
                                    prompt,
                                ],
                                output=output_video,
                                extra_kwargs={"seed": None, "randomize_seed": True},
                            )
                        with gr.Tab("Advanced Settings"):
                            # Manual control over scheduler gamma, CFG, step
                            # count, and seeding.
                            with gr.Row():
                                shift_gamma = _create_slider(GAMMA_SLIDER)
                                guidance_scale = _create_slider(GUIDANCE_SLIDER)
                            with gr.Row():
                                num_steps = _create_slider(STEP_SLIDER)
                            randomize_seed_adv = gr.Checkbox(
                                label="Randomize Seed",
                                value=True,
                                info="Checked = new random seed each run. Uncheck to provide your own seed.",
                            )
                            # Fixed seed, used only when the checkbox is off.
                            seed_adv = gr.Number(
                                label="Seed (used when Randomize is off)",
                                value=42,
                                precision=0,
                            )
                            run_btn = gr.Button("Generate", variant="primary")
                            run_btn.click(
                                fn=run_stylization,
                                inputs=[
                                    input_video,
                                    prompt,
                                    shift_gamma,
                                    num_steps,
                                    guidance_scale,
                                    seed_adv,
                                    randomize_seed_adv,
                                ],
                                outputs=[output_video],
                            )
                with gr.Column(scale=1):
                    # Clicking an example fills the video + prompt inputs.
                    gr.Examples(
                        examples=EXAMPLE_INPUTS,
                        inputs=[input_video, prompt],
                        label="Example inputs",
                    )
    return demo
# Module-level app instance; launched only when executed as a script.
demo = build_demo()
if __name__ == "__main__":
    demo.launch()