"""
ViTeX-Edit-14B inference example (self-contained).

Assumes you cloned this Hugging Face repo. All paths are resolved relative to
this script, so the bundled `diffsynth/` library, `vitex_14b.safetensors`
weights, and the full `base_model/` directory are picked up automatically.

Usage:
  python inference_example.py \
      --vace_video   path/to/source.mp4 \
      --vace_mask    path/to/mask.mp4 \
      --glyph_video  path/to/target_glyph.mp4 \
      --prompt       "HILTON" \
      --output       out.mp4

Hardware:
  - 1 × NVIDIA GPU with >= 80 GB VRAM (peak ~70 GB at 720 × 1280 × 121 frames)
  - ~250 GB CPU RAM recommended (DiT loading + activation offload)
"""

import os
import sys
import argparse
import glob

# Use the bundled diffsynth shipped alongside this script.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

import torch
from PIL import Image

from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.core import load_state_dict


BASE_DIR      = os.path.join(HERE, "base_model")
ADAPTER_CKPT  = os.path.join(HERE, "vitex_14b.safetensors")
TOKENIZER_DIR = os.path.join(BASE_DIR, "google", "umt5-xxl")

HEIGHT = 720
WIDTH  = 1280
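# Wan-style DiTs expect 4k + 1 frames (the causal video VAE compresses time
# 4x), so 121 = 4 * 30 + 1; if you override --num_frames, keep it in that form.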
NUM_FRAMES = 121
NUM_INFERENCE_STEPS = 50
CFG_SCALE = 5.0
SEED = 42


def load_video_frames(path, target_frames=NUM_FRAMES, resize=(HEIGHT, WIDTH)):
    """Load a video file into a list of PIL Images, sub-sampled or padded to
    `target_frames`, optionally resized to `(H, W)`."""
    import cv2
    import numpy as np
    cap = cv2.VideoCapture(path)
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if resize:
            img = img.resize((resize[1], resize[0]), Image.LANCZOS)  # (W, H)
        frames.append(img)
    cap.release()

    if not frames:
        raise ValueError(f"empty video: {path}")
    if target_frames and len(frames) > target_frames:
        # Too long: uniformly sub-sample down to target_frames.
        idx = np.linspace(0, len(frames) - 1, target_frames, dtype=int)
        frames = [frames[i] for i in idx]
    elif target_frames and len(frames) < target_frames:
        # Too short: pad by repeating the last frame.
        frames.extend([frames[-1]] * (target_frames - len(frames)))
    return frames
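
# Hypothetical helper, not used by main(): if you want to smoke-test the
# pipeline but have no pre-rendered glyph video, this sketches one way to
# rasterize a static white-on-black text clip with Pillow. Real glyph inputs
# are rendered per-frame inside the mask region, so treat this as a rough
# stand-in; `font_path` is an assumption (point it at any local TTF).
def render_static_glyph_video(text, num_frames=NUM_FRAMES, size=(WIDTH, HEIGHT),
                              font_path=None, font_size=160):
    from PIL import ImageDraw, ImageFont
    font = (ImageFont.truetype(font_path, font_size)
            if font_path else ImageFont.load_default())
    frame = Image.new("RGB", size, "black")
    draw = ImageDraw.Draw(frame)
    # Center the text using its rendered bounding box.
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    xy = ((size[0] - (right - left)) // 2, (size[1] - (bottom - top)) // 2)
    draw.text(xy, text, font=font, fill="white")
    return [frame] * num_frames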


def save_video(frames, path, fps=24):
    """Save list of PIL Images to an H.264 MP4."""
    import subprocess
    import numpy as np
    import imageio_ffmpeg

    ffmpeg = imageio_ffmpeg.get_ffmpeg_exe()
    w, h = frames[0].size
    cmd = [
        ffmpeg, "-y",
        "-f", "rawvideo", "-vcodec", "rawvideo",
        "-s", f"{w}x{h}", "-pix_fmt", "rgb24",
        "-r", str(fps),
        "-i", "-",
        "-c:v", "libx264", "-preset", "fast", "-crf", "18",
        "-pix_fmt", "yuv420p",  # yuv420p requires even dimensions; 1280x720 is fine
        path,
    ]
    proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL)
    for fr in frames:
        # Raw rgb24 bytes go straight to ffmpeg's stdin; force RGB mode first.
        proc.stdin.write(np.array(fr.convert("RGB")).tobytes())
    proc.stdin.close()
    if proc.wait() != 0:
        raise RuntimeError(f"ffmpeg exited non-zero while writing {path}")


def build_pipeline(device="cuda:0"):
    diffusion_shards = sorted(glob.glob(os.path.join(BASE_DIR, "diffusion_pytorch_model-*.safetensors")))
    if not diffusion_shards:
        raise FileNotFoundError(
            f"No diffusion_pytorch_model-*.safetensors found under {BASE_DIR}. "
            "Make sure you downloaded the full repo via `git lfs clone` or "
            "`huggingface-cli download ViTeX-Bench/ViTeX-Edit-14B`."
        )
    if not os.path.isfile(ADAPTER_CKPT):
        raise FileNotFoundError(f"Missing trained adapter: {ADAPTER_CKPT}")

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device=device,
        model_configs=[
            ModelConfig(path=diffusion_shards),
            ModelConfig(path=os.path.join(BASE_DIR, "models_t5_umt5-xxl-enc-bf16.pth")),
            ModelConfig(path=os.path.join(BASE_DIR, "Wan2.1_VAE.pth")),
        ],
        tokenizer_config=ModelConfig(path=TOKENIZER_DIR),
        redirect_common_files=False,
    )
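
    # Assumption: if the bundled diffsynth matches upstream DiffSynth-Studio,
    # pipe.enable_vram_management(num_persistent_param_in_dit=...) trades speed
    # for a much lower VRAM peak; left disabled since the defaults above assume
    # an 80 GB card.
    # pipe.enable_vram_management(num_persistent_param_in_dit=None)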

    print(f"Loading ViTeX-Edit-14B trained weights from {ADAPTER_CKPT}")
    state = load_state_dict(ADAPTER_CKPT)
    res = pipe.vace.load_state_dict(state, strict=False)
    print(f"  loaded {len(state)} keys (missing {len(res.missing_keys)}, unexpected {len(res.unexpected_keys)})")
    del state
    return pipe


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--vace_video",  required=True, help="Source RGB video to edit.")
    p.add_argument("--vace_mask",   required=True, help="Per-frame binary mask: 1=replace, 0=keep.")
    p.add_argument("--glyph_video", required=True, help="Pre-rendered target glyphs in the mask region.")
    p.add_argument("--prompt",      default="", help="Optional text prompt describing the edit.")
    p.add_argument("--output",      default="output.mp4")
    p.add_argument("--height", type=int, default=HEIGHT)
    p.add_argument("--width",  type=int, default=WIDTH)
    p.add_argument("--num_frames", type=int, default=NUM_FRAMES)
    p.add_argument("--num_inference_steps", type=int, default=NUM_INFERENCE_STEPS)
    p.add_argument("--cfg_scale", type=float, default=CFG_SCALE)
    p.add_argument("--seed", type=int, default=SEED)
    p.add_argument("--device", default="cuda:0")
    args = p.parse_args()

    pipe = build_pipeline(device=args.device)

    target_size = (args.height, args.width)
    vace_video = load_video_frames(args.vace_video,  args.num_frames, target_size)
    vace_mask  = load_video_frames(args.vace_mask,   args.num_frames, target_size)
    glyph      = load_video_frames(args.glyph_video, args.num_frames, target_size)
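
    # Sanity check: the three control streams must be frame-aligned (this is
    # guaranteed by load_video_frames, but fail loudly if that ever changes).
    assert len(vace_video) == len(vace_mask) == len(glyph) == args.num_frames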

    print(f"Running pipeline (seed={args.seed}, cfg={args.cfg_scale}, steps={args.num_inference_steps})...")
    out_frames = pipe(
        prompt=args.prompt,
        negative_prompt="",
        vace_video=vace_video,
        vace_video_mask=vace_mask,
        glyph_video=glyph,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        cfg_scale=args.cfg_scale,
        num_inference_steps=args.num_inference_steps,
        tiled=True,  # tiled VAE encode/decode to keep peak VRAM in check
    )

    save_video(out_frames, args.output)
    print(f"saved: {args.output}")


if __name__ == "__main__":
    main()