| """ |
| ViTeX-14B inference example (self-contained). |
| |
Assumes you cloned this HuggingFace repo in full. All paths are resolved
relative to this script's own location, so it can be run from any working
directory; the bundled `diffsynth/` library, the `vitex_14b.safetensors`
weights, and the full `base_model/` directory are picked up automatically.
| |
| Usage: |
| python inference_example.py \ |
| --vace_video path/to/source.mp4 \ |
| --vace_mask path/to/mask.mp4 \ |
| --glyph_video path/to/target_glyph.mp4 \ |
| --prompt "HILTON" \ |
| --output out.mp4 |
| |
| Hardware: |
| - 1 × NVIDIA GPU with >= 80 GB VRAM (peak ~70 GB at 720 × 1280 × 121 frames) |
| - ~250 GB CPU RAM recommended (DiT loading + activation offload) |
| """ |
|
| import os |
| import sys |
| import argparse |
| import glob |
|
| HERE = os.path.dirname(os.path.abspath(__file__)) |
| sys.path.insert(0, HERE) |
|
| import torch |
| from PIL import Image |
|
| from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig |
| from diffsynth.core import load_state_dict |
|
|
| BASE_DIR = os.path.join(HERE, "base_model") |
| ADAPTER_CKPT = os.path.join(HERE, "vitex_14b.safetensors") |
| TOKENIZER_DIR = os.path.join(BASE_DIR, "google", "umt5-xxl") |
|
| HEIGHT = 720 |
| WIDTH = 1280 |
| NUM_FRAMES = 121 |
| NUM_INFERENCE_STEPS = 50 |
| CFG_SCALE = 5.0 |
| SEED = 42 |
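
# The defaults above (720 x 1280, 121 frames) correspond to the ~70 GB peak
# VRAM figure in the docstring; smaller settings (e.g. 480 x 832, 81 frames)
# should lower the peak roughly in proportion (a rule of thumb, not a
# measured number).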
|
|
def load_video_frames(path, target_frames=NUM_FRAMES, resize=(HEIGHT, WIDTH),
                      resample=Image.LANCZOS):
    """Load a video file into a list of PIL Images, evenly sub-sampled or
    end-padded to `target_frames`, optionally resized to `(H, W)`."""
    import cv2
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise FileNotFoundError(f"cannot open video: {path}")
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        # OpenCV decodes to BGR; convert to RGB before wrapping in PIL.
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if resize:
            # PIL's resize takes (width, height); `resize` is given as (H, W).
            img = img.resize((resize[1], resize[0]), resample)
        frames.append(img)
    cap.release()
|
    if not frames:
        raise ValueError(f"empty video: {path}")
    if target_frames and len(frames) > target_frames:
        import numpy as np
        # Evenly sub-sample long clips down to exactly `target_frames`.
        idx = np.linspace(0, len(frames) - 1, target_frames, dtype=int)
        frames = [frames[i] for i in idx]
    elif target_frames and len(frames) < target_frames:
        # Pad short clips by repeating the last frame.
        frames.extend([frames[-1]] * (target_frames - len(frames)))
    return frames
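

# Example of the resampling rules above (hypothetical clip lengths): a
# 240-frame source is reduced via np.linspace(0, 239, 121) cast to int, i.e.
# 121 roughly evenly spaced frame indices; a 100-frame source is padded by
# repeating its final frame 21 times.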
|
|
def save_video(frames, path, fps=24):
    """Save a list of PIL Images to an H.264 MP4 by piping raw RGB frames to ffmpeg."""
    import subprocess
    import numpy as np
    import imageio_ffmpeg
    ffmpeg = imageio_ffmpeg.get_ffmpeg_exe()
    w, h = frames[0].size
    cmd = [
        ffmpeg, "-y",
        "-f", "rawvideo", "-vcodec", "rawvideo",
        "-s", f"{w}x{h}", "-pix_fmt", "rgb24",
        "-r", str(fps),
        "-i", "-",  # raw frames arrive on stdin
        "-c:v", "libx264", "-preset", "fast", "-crf", "18",
        "-pix_fmt", "yuv420p",  # broadest player compatibility
        path,
    ]
    proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL)
    for fr in frames:
        proc.stdin.write(np.array(fr).tobytes())
    proc.stdin.close()
    if proc.wait() != 0:
        raise RuntimeError(f"ffmpeg exited non-zero while writing {path}")
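

# I/O round-trip sketch (paths are placeholders): exercises the two helpers
# above without loading the model.
#
#   frames = load_video_frames("clip.mp4", target_frames=49, resize=(480, 832))
#   save_video(frames, "roundtrip.mp4", fps=24)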
|
|
| def build_pipeline(device="cuda:0"): |
| diffusion_shards = sorted(glob.glob(os.path.join(BASE_DIR, "diffusion_pytorch_model-*.safetensors"))) |
| if not diffusion_shards: |
        raise FileNotFoundError(
            f"No diffusion_pytorch_model-*.safetensors found under {BASE_DIR}. "
            "Make sure you downloaded the full repo, e.g. via `git clone` with "
            "Git LFS installed, or `huggingface-cli download ViTeX-Bench/ViTeX-14B`."
        )
| if not os.path.isfile(ADAPTER_CKPT): |
| raise FileNotFoundError(f"Missing trained adapter: {ADAPTER_CKPT}") |
|
| pipe = WanVideoPipeline.from_pretrained( |
| torch_dtype=torch.bfloat16, |
| device=device, |
| model_configs=[ |
| ModelConfig(path=diffusion_shards), |
| ModelConfig(path=os.path.join(BASE_DIR, "models_t5_umt5-xxl-enc-bf16.pth")), |
| ModelConfig(path=os.path.join(BASE_DIR, "Wan2.1_VAE.pth")), |
| ], |
| tokenizer_config=ModelConfig(path=TOKENIZER_DIR), |
| redirect_common_files=False, |
| ) |
|
| print(f"Loading ViTeX-14B trained weights from {ADAPTER_CKPT}") |
| state = load_state_dict(ADAPTER_CKPT) |
| res = pipe.vace.load_state_dict(state, strict=False) |
| print(f" loaded {len(state)} keys (missing {len(res.missing_keys)}, unexpected {len(res.unexpected_keys)})") |
| del state |
| return pipe |
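

# Optional: upstream DiffSynth-Studio pipelines expose enable_vram_management()
# to keep only part of the DiT resident on the GPU. Whether the bundled
# `diffsynth/` copy supports it is an assumption, hence the hasattr guard:
#
#   pipe = build_pipeline()
#   if hasattr(pipe, "enable_vram_management"):
#       pipe.enable_vram_management(num_persistent_param_in_dit=7e9)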
|
|
| def main(): |
| p = argparse.ArgumentParser() |
| p.add_argument("--vace_video", required=True, help="Source RGB video to edit.") |
| p.add_argument("--vace_mask", required=True, help="Per-frame binary mask: 1=replace, 0=keep.") |
| p.add_argument("--glyph_video", required=True, help="Pre-rendered target glyphs in the mask region.") |
| p.add_argument("--prompt", default="", help="Optional text prompt describing the edit.") |
| p.add_argument("--output", default="output.mp4") |
| p.add_argument("--height", type=int, default=HEIGHT) |
| p.add_argument("--width", type=int, default=WIDTH) |
| p.add_argument("--num_frames", type=int, default=NUM_FRAMES) |
| p.add_argument("--num_inference_steps", type=int, default=NUM_INFERENCE_STEPS) |
| p.add_argument("--cfg_scale", type=float, default=CFG_SCALE) |
| p.add_argument("--seed", type=int, default=SEED) |
| p.add_argument("--device", default="cuda:0") |
| args = p.parse_args() |
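
    # Sanity check (an assumption about the Wan2.1 base model, not something
    # this script enforces): the video VAE compresses time 4x, so frame counts
    # of the form 4k + 1 (81, 121, ...) are the expected inputs.
    if (args.num_frames - 1) % 4 != 0:
        print(f"warning: num_frames={args.num_frames} is not of the form 4k+1; "
              "the VAE may reject or mis-handle it")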
|
| pipe = build_pipeline(device=args.device) |
|
| target_size = (args.height, args.width) |
| vace_video = load_video_frames(args.vace_video, args.num_frames, target_size) |
    # Nearest-neighbor keeps the resized mask strictly binary (no gray edges).
    vace_mask = load_video_frames(args.vace_mask, args.num_frames, target_size,
                                  resample=Image.NEAREST)
| glyph = load_video_frames(args.glyph_video, args.num_frames, target_size) |
|
| print(f"Running pipeline (seed={args.seed}, cfg={args.cfg_scale}, steps={args.num_inference_steps})...") |
| out_frames = pipe( |
| prompt=args.prompt, |
| negative_prompt="", |
| vace_video=vace_video, |
| vace_video_mask=vace_mask, |
| glyph_video=glyph, |
| seed=args.seed, |
| height=args.height, |
| width=args.width, |
| num_frames=args.num_frames, |
| cfg_scale=args.cfg_scale, |
| num_inference_steps=args.num_inference_steps, |
        tiled=True,  # tile the VAE encode/decode to bound peak VRAM
| ) |
|
| save_video(out_frames, args.output) |
| print(f"saved: {args.output}") |
|
|
| if __name__ == "__main__": |
| main() |
|