# ViTeX-14B / inference_example.py
# ViTeX-Bench's picture
# Use single-word prompt example
# 4932bf3 verified
"""
ViTeX-14B inference example (self-contained).

Place this script at the root of a clone of the HuggingFace repo. The bundled
`diffsynth/` library, the `vitex_14b.safetensors` weights, and the full
`base_model/` directory are located relative to this script's own directory,
so it can be launched from any working directory.

Usage:
    python inference_example.py \\
        --vace_video path/to/source.mp4 \\
        --vace_mask path/to/mask.mp4 \\
        --glyph_video path/to/target_glyph.mp4 \\
        --prompt "HILTON" \\
        --output out.mp4

Hardware:
    - 1 × NVIDIA GPU with >= 80 GB VRAM (peak ~70 GB at 720 × 1280 × 121 frames)
    - ~250 GB CPU RAM recommended (DiT loading + activation offload)
"""
import os
import sys
import argparse
import glob
# Use the bundled diffsynth shipped alongside this script.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)
import torch
from PIL import Image
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.core import load_state_dict
# Directory holding the base model files: DiT shards, T5 encoder, VAE and tokenizer.
BASE_DIR = os.path.join(HERE, "base_model")
# Trained ViTeX weights; build_pipeline() loads them into the pipeline's VACE module.
ADAPTER_CKPT = os.path.join(HERE, "vitex_14b.safetensors")
# Tokenizer directory, passed to the pipeline as its tokenizer_config.
TOKENIZER_DIR = os.path.join(BASE_DIR, "google", "umt5-xxl")
# Default generation settings; each one can be overridden on the command line.
HEIGHT = 720  # output height in pixels (also the default resize target for inputs)
WIDTH = 1280  # output width in pixels
NUM_FRAMES = 121  # frames generated; inputs are sub-sampled or padded to this count
NUM_INFERENCE_STEPS = 50  # denoising steps
CFG_SCALE = 5.0  # classifier-free guidance scale
SEED = 42  # RNG seed passed to the pipeline
def load_video_frames(path, target_frames=NUM_FRAMES, resize=(HEIGHT, WIDTH), resample=None):
    """Load a video file into a list of RGB PIL Images.

    Each decoded frame is converted from OpenCV's BGR order to RGB and, when
    `resize` is given, resized to `(H, W)`. The clip is then sub-sampled at
    evenly spaced indices, or padded by repeating its last frame, so that
    exactly `target_frames` frames are returned. A falsy `target_frames`
    keeps the native frame count.

    Args:
        path: Video source (anything `cv2.VideoCapture` can open).
        target_frames: Number of frames to return, or a falsy value to keep all.
        resize: Output size as `(H, W)`, or a falsy value to keep the native size.
        resample: PIL resampling filter used for resizing. Defaults to
            `Image.LANCZOS` (the previous behaviour). A nearest-neighbour
            filter may be preferable for binary masks.

    Returns:
        list of PIL.Image.Image.

    Raises:
        ValueError: if the video cannot be opened or yields no frames.
    """
    import cv2
    import numpy as np

    if resample is None:
        resample = Image.LANCZOS

    cap = cv2.VideoCapture(path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"cannot open video: {path}")
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if resize:
                img = img.resize((resize[1], resize[0]), resample)  # PIL expects (W, H)
            frames.append(img)
    finally:
        # Release the decoder even if conversion or resizing raises mid-stream.
        cap.release()

    if not frames:
        raise ValueError(f"empty video: {path}")

    if target_frames and len(frames) > target_frames:
        # Evenly spaced indices that always include the first and last frame.
        idx = np.linspace(0, len(frames) - 1, target_frames, dtype=int)
        frames = [frames[i] for i in idx]
    elif target_frames and len(frames) < target_frames:
        # Pad short clips by holding the final frame.
        frames.extend([frames[-1]] * (target_frames - len(frames)))
    return frames
def save_video(frames, path, fps=24):
    """Encode a sequence of PIL Images to an H.264 MP4 with the ffmpeg binary
    provided by `imageio_ffmpeg`.

    Frames are streamed to ffmpeg as raw RGB24 over stdin and encoded with
    libx264 (preset "fast", CRF 18, yuv420p output).

    Args:
        frames: Non-empty indexable sequence of PIL Images, all the same size.
        path: Output file; overwritten if it already exists.
        fps: Output frame rate.

    Raises:
        ValueError: if `frames` is empty or the frame sizes differ.
        RuntimeError: if ffmpeg exits with a non-zero status. ffmpeg's error
            output is included in the message.
    """
    import contextlib
    import subprocess
    import tempfile

    import imageio_ffmpeg
    import numpy as np

    if not frames:
        raise ValueError(f"no frames to write to {path}")
    w, h = frames[0].size
    # rawvideo has no per-frame header, so every frame must match the declared -s size.
    for i, fr in enumerate(frames):
        if fr.size != (w, h):
            raise ValueError(f"frame {i} has size {fr.size}, expected {(w, h)}")

    ffmpeg = imageio_ffmpeg.get_ffmpeg_exe()
    cmd = [
        ffmpeg, "-y", "-loglevel", "error",
        "-f", "rawvideo", "-vcodec", "rawvideo",
        "-s", f"{w}x{h}", "-pix_fmt", "rgb24",
        "-r", str(fps),
        "-i", "-",
        "-c:v", "libx264", "-preset", "fast", "-crf", "18",
        "-pix_fmt", "yuv420p",
        path,
    ]
    # Send ffmpeg's diagnostics to a temporary file rather than a pipe. Nobody
    # drains a stderr pipe while frames are written, so it could fill up and
    # block ffmpeg mid-encode.
    with tempfile.TemporaryFile() as err:
        proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=err)
        try:
            for fr in frames:
                # convert("RGB") guarantees 3 bytes per pixel, matching rgb24.
                # RGBA or palette images would otherwise corrupt the stream.
                proc.stdin.write(np.asarray(fr.convert("RGB")).tobytes())
        except BrokenPipeError:
            pass  # ffmpeg exited early; its status and stderr are reported below.
        finally:
            with contextlib.suppress(BrokenPipeError):
                proc.stdin.close()
            returncode = proc.wait()
        if returncode != 0:
            err.seek(0)
            detail = err.read().decode("utf-8", errors="replace").strip()
            raise RuntimeError(
                f"ffmpeg exited with status {returncode} while writing {path}: {detail}"
            )
def build_pipeline(device="cuda:0"):
    """Build the Wan video pipeline from `base_model/` and load the ViTeX weights.

    The pipeline is created in bfloat16 from the sharded DiT checkpoint, the
    UMT5-XXL text encoder, the Wan2.1 VAE and the UMT5 tokenizer. The trained
    ViTeX state dict is then loaded into `pipe.vace` with `strict=False`.

    Args:
        device: Torch device string the pipeline is created on.

    Returns:
        The `WanVideoPipeline` with the ViTeX weights applied to `pipe.vace`.

    Raises:
        FileNotFoundError: if the diffusion shards, T5 encoder, VAE, tokenizer
            or trained adapter checkpoint are missing.
        RuntimeError: if the pipeline has no VACE module, or none of the
            adapter's tensors match it.
    """
    diffusion_shards = sorted(glob.glob(os.path.join(BASE_DIR, "diffusion_pytorch_model-*.safetensors")))
    if not diffusion_shards:
        raise FileNotFoundError(
            f"No diffusion_pytorch_model-*.safetensors found under {BASE_DIR}. "
            "Make sure you downloaded the full repo via `git lfs clone` or "
            "`huggingface-cli download ViTeX-Bench/ViTeX-14B`."
        )
    t5_path = os.path.join(BASE_DIR, "models_t5_umt5-xxl-enc-bf16.pth")
    vae_path = os.path.join(BASE_DIR, "Wan2.1_VAE.pth")
    # Check every input up front, so an incomplete download fails immediately
    # rather than after the multi-shard DiT has been loaded.
    for required in (t5_path, vae_path, TOKENIZER_DIR):
        if not os.path.exists(required):
            raise FileNotFoundError(f"Missing base-model file: {required}")
    if not os.path.isfile(ADAPTER_CKPT):
        raise FileNotFoundError(f"Missing trained adapter: {ADAPTER_CKPT}")

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device=device,
        model_configs=[
            ModelConfig(path=diffusion_shards),
            ModelConfig(path=t5_path),
            ModelConfig(path=vae_path),
        ],
        tokenizer_config=ModelConfig(path=TOKENIZER_DIR),
        redirect_common_files=False,
    )
    vace = getattr(pipe, "vace", None)
    if vace is None:
        raise RuntimeError(
            "Pipeline has no VACE module to load the ViTeX weights into; "
            f"check the diffusion shards under {BASE_DIR}."
        )

    print(f"Loading ViTeX-14B trained weights from {ADAPTER_CKPT}")
    state = load_state_dict(ADAPTER_CKPT)
    res = vace.load_state_dict(state, strict=False)
    n_unexpected = len(res.unexpected_keys)
    print(f" loaded {len(state)} keys (missing {len(res.missing_keys)}, unexpected {len(res.unexpected_keys)})")
    # strict=False tolerates keys that don't line up. Make sure something was
    # actually applied; otherwise inference would silently run without the
    # ViTeX weights.
    if n_unexpected == len(state):
        raise RuntimeError(
            f"None of the {len(state)} tensors in {ADAPTER_CKPT} match pipe.vace; "
            "the ViTeX weights were not applied."
        )
    if n_unexpected:
        print(f" warning: ignored unexpected keys, e.g. {list(res.unexpected_keys)[:5]}")
    del state  # drop the CPU copy of the adapter weights before inference
    return pipe
def main(argv=None):
    """Command-line entry point: edit text in a video with ViTeX-14B.

    The function runs these steps in order:
    1. Parse the arguments and prepare the output directory.
    2. Decode the three conditioning videos.
    3. Build the pipeline and run inference.
    4. Write the result to `--output`.

    Args:
        argv: Optional argument list (defaults to `sys.argv[1:]`), useful
            when driving the script programmatically.
    """
    p = argparse.ArgumentParser(description="ViTeX-14B inference example.")
    p.add_argument("--vace_video", required=True, help="Source RGB video to edit.")
    p.add_argument("--vace_mask", required=True, help="Per-frame binary mask: 1=replace, 0=keep.")
    p.add_argument("--glyph_video", required=True, help="Pre-rendered target glyphs in the mask region.")
    p.add_argument("--prompt", default="", help="Optional text prompt describing the edit.")
    p.add_argument("--negative_prompt", default="", help="Negative prompt passed to the pipeline (default: empty).")
    p.add_argument("--output", default="output.mp4")
    p.add_argument("--height", type=int, default=HEIGHT)
    p.add_argument("--width", type=int, default=WIDTH)
    p.add_argument("--num_frames", type=int, default=NUM_FRAMES)
    p.add_argument("--num_inference_steps", type=int, default=NUM_INFERENCE_STEPS)
    p.add_argument("--cfg_scale", type=float, default=CFG_SCALE)
    p.add_argument("--seed", type=int, default=SEED)
    p.add_argument("--device", default="cuda:0")
    args = p.parse_args(argv)

    # Make sure the output location exists before the slow model load and
    # inference, so the final write cannot fail on a missing directory.
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)

    # Decode the inputs first: a bad path or unreadable video then fails in
    # seconds instead of after the 14B model has been loaded.
    target_size = (args.height, args.width)
    vace_video = load_video_frames(args.vace_video, args.num_frames, target_size)
    vace_mask = load_video_frames(args.vace_mask, args.num_frames, target_size)
    glyph = load_video_frames(args.glyph_video, args.num_frames, target_size)

    pipe = build_pipeline(device=args.device)

    print(f"Running pipeline (seed={args.seed}, cfg={args.cfg_scale}, steps={args.num_inference_steps})...")
    out_frames = pipe(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        vace_video=vace_video,
        vace_video_mask=vace_mask,
        glyph_video=glyph,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        cfg_scale=args.cfg_scale,
        num_inference_steps=args.num_inference_steps,
        tiled=True,
    )
    save_video(out_frames, args.output)
    print(f"saved: {args.output}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()