"""Gradio demo for Large Video Planner (LVP) image-to-video generation."""

import sys
import uuid
from pathlib import Path

import gradio as gr
import numpy as np
import spaces
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from hydra import compose, initialize
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms

# Make the repository root importable so project-local modules resolve.
sys.path.append(str(Path(__file__).resolve().parent.parent))

from algorithms.wan.wan_i2v import WanImageToVideo
from utils.video_utils import numpy_to_mp4_bytes

DEVICE = "cuda"
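

# All weights are fetched from the Hugging Face Hub: the fine-tuned LVP
# checkpoint, plus the text encoder, VAE, and CLIP weights from the
# Wan2.1 I2V 14B 480P base release.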
def load_model() -> WanImageToVideo:
    """Download all checkpoints and build the WanImageToVideo model."""
    print("Downloading model...")
    # Fine-tuned LVP diffusion weights.
    ckpt_path = hf_hub_download(
        repo_id="KempnerInstituteAI/LVP",
        filename="checkpoints/LVP_14B_inference.ckpt",
        cache_dir="./huggingface",
    )
    # UMT5 text encoder, VAE, and CLIP weights from the Wan2.1 base model.
    umt5_path = hf_hub_download(
        repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
        filename="models_t5_umt5-xxl-enc-bf16.pth",
        cache_dir="./huggingface",
    )
    vae_path = hf_hub_download(
        repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
        filename="Wan2.1_VAE.pth",
        cache_dir="./huggingface",
    )
    clip_path = hf_hub_download(
        repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
        filename="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        cache_dir="./huggingface",
    )
    config_path = hf_hub_download(
        repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
        filename="config.json",
        cache_dir="./huggingface/Wan2.1-I2V-14B-480P",
    )

    # Compose the repository's Hydra config and point it at the downloads.
    with initialize(version_base=None, config_path="./configurations"):
        cfg = compose(
            config_name="config",
            overrides=[
                "experiment=exp_video",
                "algorithm=wan_i2v",
                "dataset=dummy",
                "experiment.tasks=[test]",
                "algorithm.sample_steps=40",
                "algorithm.load_prompt_embed=False",
                f"algorithm.model.tuned_ckpt_path={ckpt_path}",
                f"algorithm.text_encoder.ckpt_path={umt5_path}",
                f"algorithm.vae.ckpt_path={vae_path}",
                f"algorithm.clip.ckpt_path={clip_path}",
                f"algorithm.model.ckpt_path={Path(config_path).parent}",
            ],
        )
    OmegaConf.resolve(cfg)
    cfg = cfg.algorithm

    print("Initializing model...")
    _model = WanImageToVideo(cfg)
    print("Configuring model...")
    _model.configure_model()
    _model = _model.eval().to(DEVICE)
    _model.vae_scale = [_model.vae_mean, _model.vae_inv_std]
    return _model
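

# Preprocessing: with scale pinned to 1.0 and a single target ratio,
# RandomResizedCrop reduces to a deterministic crop to the model's aspect
# ratio, resized bicubically; Normalize maps pixels to [-1, 1] (undone after
# sampling).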
def load_transform(height: int, width: int):
    return transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.RandomResizedCrop(
                size=(height, width),
                scale=(1.0, 1.0),
                ratio=(width / height, width / height),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
        ]
    )
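

# Load the model and transform once at startup so every request reuses them.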
model = load_model()
print("Model loaded successfully")
transform = load_transform(model.height, model.width)
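

# Estimate the GPU time (seconds) to request from ZeroGPU for one generation.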
def get_duration(
    image: str,
    prompt: str,
    sample_steps: int,
    lang_guidance: float,
    hist_guidance: float,
    progress: gr.Progress,
) -> int:
    step_duration = 5  # rough seconds per sampling step, per forward pass
    # One base pass per step, plus one for each nonzero guidance scale;
    # equal nonzero language/history scales are counted once.
    multiplier = 1 + int(lang_guidance > 0) + int(hist_guidance > 0) - int(lang_guidance == hist_guidance and lang_guidance > 0)
    return int(20 + sample_steps * multiplier * step_duration)
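

# Sampling runs without gradients, under bfloat16 autocast, on a GPU held for
# the duration estimated by get_duration.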
@spaces.GPU(duration=get_duration)
@torch.no_grad()
@torch.autocast(DEVICE, dtype=torch.bfloat16)
def infer_i2v(
    image: str,
    prompt: str,
    sample_steps: int,
    lang_guidance: float,
    hist_guidance: float,
    progress: gr.Progress = gr.Progress(),
) -> str:
    """Run I2V inference, given an image path, prompt, and sampling parameters."""
    image = transform(Image.open(image).convert("RGB"))
    # Start from Gaussian noise and condition on the input image as frame 0.
    videos = torch.randn(1, model.n_frames, 3, model.height, model.width, device=DEVICE)
    videos[:, 0] = image[None]
    batch = {
        "videos": videos,
        "prompts": [prompt],
        "has_bbox": torch.zeros(1, 2, device=DEVICE).bool(),
        "bbox_render": torch.zeros(1, 2, model.height, model.width, device=DEVICE),
    }
    model.hist_guidance = hist_guidance
    model.lang_guidance = lang_guidance
    model.sample_steps = sample_steps
    pbar = progress.tqdm(range(sample_steps), desc="Sampling")
    video = rearrange(
        model.sample_seq(batch, pbar=pbar).squeeze(0), "t c h w -> t h w c"
    )
    # Map frames from [-1, 1] back to uint8 pixels.
    video = video.float().cpu().numpy()
    video = np.clip(video * 0.5 + 0.5, 0, 1)
    video = (video * 255).astype(np.uint8)
    video_bytes = numpy_to_mp4_bytes(video, fps=model.cfg.logging.fps)
    # Write the encoded MP4 to a unique path for Gradio to serve.
    videos_dir = Path("./videos")
    videos_dir.mkdir(exist_ok=True)
    video_path = videos_dir / f"{uuid.uuid4()}.mp4"
    with open(video_path, "wb") as f:
        f.write(video_bytes)
    return video_path.as_posix()
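

# Build the example gallery: each image filename is expected to carry a
# two-character prefix followed by the prompt, with underscores for spaces.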
examples_dir = Path("examples")
examples = []
if examples_dir.exists():
    for image_path in sorted(examples_dir.iterdir()):
        if not image_path.is_file():
            continue
        examples.append([image_path.as_posix(), image_path.stem[2:].replace("_", " ")])
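

# Gradio UI: custom CSS for the sidebar buttons and example gallery, a sidebar
# with links and troubleshooting notes, input/output panels, and examples.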
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.HTML(
            """
            <style>
            .header-button-row {
                gap: 4px !important;
            }
            .header-button-row div {
                width: 131px !important;
            }
            .header-button-column {
                width: 131px !important;
                gap: 5px !important;
            }
            .header-button a {
                border: 1px solid #e4e4e7;
            }
            .header-button .button-icon {
                margin-right: 8px;
            }
            #sample-gallery table {
                width: 100% !important;
            }
            #sample-gallery td:first-child {
                width: 25% !important;
            }
            #sample-gallery .border.table,
            #sample-gallery .container.table,
            #sample-gallery .container {
                max-height: none !important;
                height: auto !important;
                max-width: none !important;
                width: 100% !important;
            }
            #sample-gallery img {
                width: 100% !important;
                height: auto !important;
                object-fit: contain !important;
            }
            </style>
            """
        )

        with gr.Sidebar():
            gr.Markdown("# Large Video Planner")
            gr.Markdown(
                "### Official Interactive Demo for [_Large Video Planner Enables Generalizable Robot Control_](todo)"
            )
            gr.Markdown("---")
            gr.Markdown("#### Links ↓")
            with gr.Row(elem_classes=["header-button-row"]):
                with gr.Column(elem_classes=["header-button-column"], min_width=0):
                    gr.Button(
                        value="Website",
                        link="https://www.boyuan.space/large-video-planner/",
                        icon="https://simpleicons.org/icons/googlechrome.svg",
                        elem_classes=["header-button"],
                        size="md",
                        min_width=0,
                    )
                    gr.Button(
                        value="Paper",
                        link="todo",
                        icon="https://simpleicons.org/icons/arxiv.svg",
                        elem_classes=["header-button"],
                        size="md",
                        min_width=0,
                    )
                with gr.Column(elem_classes=["header-button-column"], min_width=0):
                    gr.Button(
                        value="Code",
                        link="https://github.com/buoyancy99/large-video-planner",
                        icon="https://simpleicons.org/icons/github.svg",
                        elem_classes=["header-button"],
                        size="md",
                        min_width=0,
                    )
                    gr.Button(
                        value="Weights",
                        link="https://huggingface.co/large-video-planner/LVP",
                        icon="https://simpleicons.org/icons/huggingface.svg",
                        elem_classes=["header-button"],
                        size="md",
                        min_width=0,
                    )
            gr.Markdown("---")
            gr.Markdown("#### Troubleshooting ↓")
            with gr.Group():
                with gr.Accordion("Error or Unexpected Results?", open=False):
                    gr.Markdown(
                        "Refresh the page and try again, and avoid clicking the same button multiple times."
                    )
                with gr.Accordion("Too Slow or No GPU Allocation?", open=False):
                    gr.Markdown(
                        "This demo may respond slowly because it runs a large, non-distilled model. "
                        "Consider running it locally (click the dots in the top-right corner), or "
                        "subscribe to Hugging Face Pro for a larger GPU quota."
                    )
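
        # Main layout: image and prompt inputs, sampling controls, video output.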
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image", type="filepath")
                prompt_input = gr.Textbox(label="Prompt", lines=2, max_lines=2)
            with gr.Column():
                sample_steps_slider = gr.Slider(
                    label="Sampling Steps",
                    minimum=10,
                    maximum=50,
                    value=30,
                    step=1,
                )
                lang_guidance_slider = gr.Slider(
                    label="Language Guidance (recommended 1.5-2.5)",
                    minimum=0,
                    maximum=5,
                    value=2.5,
                    step=0.1,
                )
                hist_guidance_slider = gr.Slider(
                    label="History Guidance (recommended 1.0-2.0)",
                    minimum=0,
                    maximum=5,
                    value=1.5,
                    step=0.1,
                )
                run_button = gr.Button("Generate Video")
            with gr.Column():
                video_output = gr.Video(label="Generated Video")
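
        # Clicking an example only fills the inputs (run_on_click=False); the
        # user still presses "Generate Video" to run inference.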
        gr.Examples(
            examples=examples,
            inputs=[image_input, prompt_input],
            outputs=[video_output],
            run_on_click=False,
            elem_id="sample-gallery",
        )
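
        # Wire the generate button to the inference function.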
        run_button.click(
            fn=infer_i2v,
            inputs=[
                image_input,
                prompt_input,
                sample_steps_slider,
                lang_guidance_slider,
                hist_guidance_slider,
            ],
            outputs=video_output,
        )

    demo.launch(share=True)