ViTeX-Bench commited on
Commit
7b0e7d8
·
verified ·
1 Parent(s): 151ad29

Add inference_example.py

Browse files
Files changed (1) hide show
  1. inference_example.py +171 -0
inference_example.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ViTeX-14B inference example.
3
+
4
+ Loads:
5
+ - Wan-AI/Wan2.1-VACE-14B (base model)
6
+ - ViTeX-Bench/ViTeX-14B (this fine-tuned VACE module)
7
+
8
+ Runs one or more video text-edit jobs, writing MP4 outputs.
9
+
10
+ Requires:
11
+ - The DiffSynth-Studio-TextVACE fork (provides GlyphEncoder + ConditionCrossAttention)
12
+ - torch >= 2.7.0+cu128 (NCCL >= 2.25.1 recommended on H100)
13
+ - One NVIDIA GPU with >= 80 GB VRAM (H100 / A100 80 GB)
14
+ - imageio-ffmpeg, opencv-python
15
+
16
+ Usage:
17
+ python inference_example.py \
18
+ --vace_video path/to/source.mp4 \
19
+ --vace_mask path/to/mask.mp4 \
20
+ --glyph_video path/to/target_glyph.mp4 \
21
+ --prompt "Change the sign to read 'HILTON'" \
22
+ --output out.mp4
23
+ """
24
+
25
+ import os
26
+ import argparse
27
+ import glob
28
+
29
+ import torch
30
+ from PIL import Image
31
+
32
+ from huggingface_hub import snapshot_download
33
+
34
+ from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
35
+ from diffsynth.core import load_state_dict
36
+
37
+
38
+ HEIGHT = 720
39
+ WIDTH = 1280
40
+ NUM_FRAMES = 121
41
+ NUM_INFERENCE_STEPS = 50
42
+ CFG_SCALE = 5.0
43
+ SEED = 42
44
+
45
+
46
+ def load_video_frames(path, target_frames=NUM_FRAMES, resize=(HEIGHT, WIDTH)):
47
+ """Load a video file into a list of PIL Images, optionally subsampling/padding."""
48
+ import cv2
49
+ cap = cv2.VideoCapture(path)
50
+ frames = []
51
+ while True:
52
+ ok, frame = cap.read()
53
+ if not ok:
54
+ break
55
+ img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
56
+ if resize:
57
+ img = img.resize((resize[1], resize[0]), Image.LANCZOS) # (W, H)
58
+ frames.append(img)
59
+ cap.release()
60
+
61
+ if not frames:
62
+ raise ValueError(f"empty video: {path}")
63
+
64
+ if target_frames and len(frames) > target_frames:
65
+ import numpy as np
66
+ idx = np.linspace(0, len(frames) - 1, target_frames, dtype=int)
67
+ frames = [frames[i] for i in idx]
68
+ elif target_frames and len(frames) < target_frames:
69
+ frames.extend([frames[-1]] * (target_frames - len(frames)))
70
+ return frames
71
+
72
+
73
+ def save_video(frames, path, fps=24):
74
+ """Save list of PIL Images to an H.264 MP4."""
75
+ import subprocess, numpy as np
76
+ import imageio_ffmpeg
77
+ ffmpeg = imageio_ffmpeg.get_ffmpeg_exe()
78
+ w, h = frames[0].size
79
+ cmd = [
80
+ ffmpeg, "-y",
81
+ "-f", "rawvideo", "-vcodec", "rawvideo",
82
+ "-s", f"{w}x{h}", "-pix_fmt", "rgb24",
83
+ "-r", str(fps),
84
+ "-i", "-",
85
+ "-c:v", "libx264", "-preset", "fast", "-crf", "18",
86
+ "-pix_fmt", "yuv420p",
87
+ path,
88
+ ]
89
+ proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL)
90
+ for fr in frames:
91
+ proc.stdin.write(np.array(fr).tobytes())
92
+ proc.stdin.close()
93
+ proc.wait()
94
+
95
+
96
+ def build_pipeline(base_dir, ckpt_path, device="cuda:0"):
97
+ diffusion_shards = sorted(glob.glob(os.path.join(base_dir, "diffusion_pytorch_model-*.safetensors")))
98
+ pipe = WanVideoPipeline.from_pretrained(
99
+ torch_dtype=torch.bfloat16,
100
+ device=device,
101
+ model_configs=[
102
+ ModelConfig(path=diffusion_shards),
103
+ ModelConfig(path=os.path.join(base_dir, "models_t5_umt5-xxl-enc-bf16.pth")),
104
+ ModelConfig(path=os.path.join(base_dir, "Wan2.1_VAE.pth")),
105
+ ],
106
+ tokenizer_config=ModelConfig(path=os.path.join(base_dir, "google/umt5-xxl")),
107
+ redirect_common_files=False,
108
+ )
109
+ print(f"Loading ViTeX-14B weights from {ckpt_path}")
110
+ state = load_state_dict(ckpt_path)
111
+ res = pipe.vace.load_state_dict(state, strict=False)
112
+ print(f" loaded {len(state)} keys (missing {len(res.missing_keys)}, unexpected {len(res.unexpected_keys)})")
113
+ del state
114
+ return pipe
115
+
116
+
117
+ def main():
118
+ p = argparse.ArgumentParser()
119
+ p.add_argument("--vace_video", required=True, help="Source RGB video (the one to edit).")
120
+ p.add_argument("--vace_mask", required=True, help="Per-frame binary mask: 1=replace, 0=keep.")
121
+ p.add_argument("--glyph_video", required=True, help="Pre-rendered target glyphs placed in the mask region.")
122
+ p.add_argument("--prompt", default="", help="Optional text prompt describing the edit.")
123
+ p.add_argument("--output", default="output.mp4")
124
+ p.add_argument("--height", type=int, default=HEIGHT)
125
+ p.add_argument("--width", type=int, default=WIDTH)
126
+ p.add_argument("--num_frames", type=int, default=NUM_FRAMES)
127
+ p.add_argument("--num_inference_steps", type=int, default=NUM_INFERENCE_STEPS)
128
+ p.add_argument("--cfg_scale", type=float, default=CFG_SCALE)
129
+ p.add_argument("--seed", type=int, default=SEED)
130
+ p.add_argument("--device", default="cuda:0")
131
+ args = p.parse_args()
132
+
133
+ # 1. Download base + this model
134
+ print("Downloading Wan-AI/Wan2.1-VACE-14B (base, ~60 GB)...")
135
+ base_dir = snapshot_download("Wan-AI/Wan2.1-VACE-14B")
136
+ print("Downloading ViTeX-Bench/ViTeX-14B (this model, ~8 GB)...")
137
+ vitex_dir = snapshot_download("ViTeX-Bench/ViTeX-14B")
138
+ ckpt_path = os.path.join(vitex_dir, "vitex_14b.safetensors")
139
+
140
+ # 2. Build pipeline
141
+ pipe = build_pipeline(base_dir, ckpt_path, device=args.device)
142
+
143
+ # 3. Load inputs
144
+ target_size = (args.height, args.width)
145
+ vace_video = load_video_frames(args.vace_video, args.num_frames, target_size)
146
+ vace_mask = load_video_frames(args.vace_mask, args.num_frames, target_size)
147
+ glyph = load_video_frames(args.glyph_video, args.num_frames, target_size)
148
+
149
+ # 4. Run
150
+ print(f"Running pipeline (seed={args.seed}, cfg={args.cfg_scale}, steps={args.num_inference_steps})...")
151
+ out_frames = pipe(
152
+ prompt=args.prompt,
153
+ negative_prompt="",
154
+ vace_video=vace_video,
155
+ vace_video_mask=vace_mask,
156
+ glyph_video=glyph,
157
+ seed=args.seed,
158
+ height=args.height,
159
+ width=args.width,
160
+ num_frames=args.num_frames,
161
+ cfg_scale=args.cfg_scale,
162
+ num_inference_steps=args.num_inference_steps,
163
+ tiled=True,
164
+ )
165
+
166
+ save_video(out_frames, args.output)
167
+ print(f"saved: {args.output}")
168
+
169
+
170
+ if __name__ == "__main__":
171
+ main()