"""
ViTeX-Edit-14B inference example (self-contained).

Assumes you cloned this HuggingFace repo and are running this script from the
repo root. The bundled `diffsynth/` library, `vitex_14b.safetensors` weights,
and the full `base_model/` directory are picked up automatically.

Usage:
    python inference_example.py \
        --vace_video path/to/source.mp4 \
        --vace_mask path/to/mask.mp4 \
        --glyph_video path/to/target_glyph.mp4 \
        --prompt "HILTON" \
        --output out.mp4

Hardware:
    - 1 × NVIDIA GPU with >= 80 GB VRAM (peak ~70 GB at 720 × 1280 × 121 frames)
    - ~250 GB CPU RAM recommended (DiT loading + activation offload)
"""
import os
import sys
import argparse
import glob

# Use the bundled diffsynth shipped alongside this script.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

import torch
from PIL import Image

from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.core import load_state_dict

BASE_DIR = os.path.join(HERE, "base_model")
ADAPTER_CKPT = os.path.join(HERE, "vitex_14b.safetensors")
TOKENIZER_DIR = os.path.join(BASE_DIR, "google", "umt5-xxl")
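
# Defaults match the 720p setup above. Note 121 = 4 * 30 + 1: Wan-style video
# VAEs compress time 4x, so frame counts of the form 4k + 1 are the safe
# choice (an assumption carried over from upstream Wan2.1; other counts may
# be rejected by the pipeline).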
HEIGHT = 720
WIDTH = 1280
NUM_FRAMES = 121
NUM_INFERENCE_STEPS = 50
CFG_SCALE = 5.0
SEED = 42
def load_video_frames(path, target_frames=NUM_FRAMES, resize=(HEIGHT, WIDTH)):
    """Load a video file into a list of PIL Images, sub-sampled or padded to
    `target_frames`, optionally resized to `(H, W)`."""
    import cv2

    cap = cv2.VideoCapture(path)
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if resize:
            img = img.resize((resize[1], resize[0]), Image.LANCZOS)  # PIL wants (W, H)
        frames.append(img)
    cap.release()

    if not frames:
        raise ValueError(f"empty video: {path}")
    if target_frames and len(frames) > target_frames:
        # Too many frames: sub-sample uniformly across the clip.
        import numpy as np
        idx = np.linspace(0, len(frames) - 1, target_frames, dtype=int)
        frames = [frames[i] for i in idx]
    elif target_frames and len(frames) < target_frames:
        # Too few frames: pad by repeating the last frame.
        frames.extend([frames[-1]] * (target_frames - len(frames)))
    return frames
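
# Optional sketch (not wired into main): if your mask video is antialiased or
# lossily compressed, the LANCZOS resize above can leave gray fringes at glyph
# edges. A hard threshold restores the strict 1=replace / 0=keep contract the
# pipeline expects. `binarize_mask_frames` and its 127 cutoff are illustrative
# assumptions, not part of the original script.
def binarize_mask_frames(frames, threshold=127):
    """Threshold mask frames to strictly black/white RGB PIL Images."""
    import numpy as np
    out = []
    for fr in frames:
        arr = np.array(fr.convert("L"))
        hard = np.where(arr > threshold, 255, 0).astype(np.uint8)
        out.append(Image.fromarray(hard).convert("RGB"))
    return out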
def save_video(frames, path, fps=24):
    """Save a list of PIL Images to an H.264 MP4."""
    import subprocess
    import numpy as np
    import imageio_ffmpeg

    ffmpeg = imageio_ffmpeg.get_ffmpeg_exe()
    w, h = frames[0].size
    cmd = [
        ffmpeg, "-y",
        "-f", "rawvideo", "-vcodec", "rawvideo",
        "-s", f"{w}x{h}", "-pix_fmt", "rgb24",
        "-r", str(fps),
        "-i", "-",
        "-c:v", "libx264", "-preset", "fast", "-crf", "18",
        "-pix_fmt", "yuv420p",
        path,
    ]
    # Stream raw RGB frames into ffmpeg's stdin; ffmpeg handles the encode.
    proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL)
    for fr in frames:
        proc.stdin.write(np.array(fr).tobytes())
    proc.stdin.close()
    proc.wait()
def build_pipeline(device="cuda:0"):
    diffusion_shards = sorted(
        glob.glob(os.path.join(BASE_DIR, "diffusion_pytorch_model-*.safetensors"))
    )
    if not diffusion_shards:
        raise FileNotFoundError(
            f"No diffusion_pytorch_model-*.safetensors found under {BASE_DIR}. "
            "Make sure you downloaded the full repo via `git lfs clone` or "
            "`huggingface-cli download ViTeX-Bench/ViTeX-Edit-14B`."
        )
    if not os.path.isfile(ADAPTER_CKPT):
        raise FileNotFoundError(f"Missing trained adapter: {ADAPTER_CKPT}")

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device=device,
        model_configs=[
            ModelConfig(path=diffusion_shards),
            ModelConfig(path=os.path.join(BASE_DIR, "models_t5_umt5-xxl-enc-bf16.pth")),
            ModelConfig(path=os.path.join(BASE_DIR, "Wan2.1_VAE.pth")),
        ],
        tokenizer_config=ModelConfig(path=TOKENIZER_DIR),
        redirect_common_files=False,
    )

    print(f"Loading ViTeX-Edit-14B trained weights from {ADAPTER_CKPT}")
    state = load_state_dict(ADAPTER_CKPT)
    res = pipe.vace.load_state_dict(state, strict=False)
    print(f" loaded {len(state)} keys (missing {len(res.missing_keys)}, unexpected {len(res.unexpected_keys)})")
    del state
    return pipe
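
# Note: if the bundled diffsynth tracks upstream DiffSynth-Studio, peak VRAM
# can be traded for speed by calling something like
#     pipe.enable_vram_management(num_persistent_param_in_dit=0)
# right after build_pipeline(). This API is an assumption about the bundled
# copy; check diffsynth/pipelines/wan_video.py before relying on it.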
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--vace_video", required=True, help="Source RGB video to edit.")
    p.add_argument("--vace_mask", required=True, help="Per-frame binary mask: 1=replace, 0=keep.")
    p.add_argument("--glyph_video", required=True, help="Pre-rendered target glyphs in the mask region.")
    p.add_argument("--prompt", default="", help="Optional text prompt describing the edit.")
    p.add_argument("--output", default="output.mp4")
    p.add_argument("--height", type=int, default=HEIGHT)
    p.add_argument("--width", type=int, default=WIDTH)
    p.add_argument("--num_frames", type=int, default=NUM_FRAMES)
    p.add_argument("--num_inference_steps", type=int, default=NUM_INFERENCE_STEPS)
    p.add_argument("--cfg_scale", type=float, default=CFG_SCALE)
    p.add_argument("--seed", type=int, default=SEED)
    p.add_argument("--device", default="cuda:0")
    args = p.parse_args()

    pipe = build_pipeline(device=args.device)

    target_size = (args.height, args.width)
    vace_video = load_video_frames(args.vace_video, args.num_frames, target_size)
    vace_mask = load_video_frames(args.vace_mask, args.num_frames, target_size)
    glyph = load_video_frames(args.glyph_video, args.num_frames, target_size)

    print(f"Running pipeline (seed={args.seed}, cfg={args.cfg_scale}, steps={args.num_inference_steps})...")
    out_frames = pipe(
        prompt=args.prompt,
        negative_prompt="",
        vace_video=vace_video,
        vace_video_mask=vace_mask,
        glyph_video=glyph,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        cfg_scale=args.cfg_scale,
        num_inference_steps=args.num_inference_steps,
        tiled=True,
    )

    save_video(out_frames, args.output)
    print(f"saved: {args.output}")


if __name__ == "__main__":
    main()