Instructions to use ViTeX-Bench/ViTeX-Edit-14B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use ViTeX-Bench/ViTeX-Edit-14B with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("ViTeX-Bench/ViTeX-Edit-14B", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Add ViTeX-14B (Corp): training-free locality-preserving post-processing wrapper
Browse files`make_corp_baseline.py` composes a raw ViTeX-14B prediction's text region back
onto the source video with two per-frame operations:
1. Reinhard mean-variance LAB color matching on a 20-px band just outside
the mask, so the predicted glyphs match the source's local lighting.
2. Signed-distance feathered alpha compositing (4-px feather centered on
the mask boundary), so the boundary has no visible seam.
Inside the mask the result is the predicted glyphs (color-matched); outside
the feather the result is byte-identical to the source. SeqAcc / CharAcc /
TTS are within ~0.01 of raw ViTeX-14B (the predicted text region is
unchanged), while PSNR / SSIM / LPIPS / DreamSim jump to near-Identity.
CPU-only, ~5 min on 8 workers for the 157-clip evaluation split. Reference:
appendix G of the ViTeX-Bench paper.
- README.md +26 -1
- make_corp_baseline.py +187 -0
|
@@ -32,7 +32,8 @@ This repository is fully self-contained — it bundles the trained weights, the
|
|
| 32 |
.
|
| 33 |
├── README.md
|
| 34 |
├── requirements.txt
|
| 35 |
-
├── inference_example.py
|
|
|
|
| 36 |
├── vitex_14b.safetensors (8 GB — trained adapter weights)
|
| 37 |
├── diffsynth/ (bundled inference library)
|
| 38 |
└── base_model/ (70 GB — frozen base model files)
|
|
@@ -77,6 +78,30 @@ python inference_example.py \
|
|
| 77 |
|
| 78 |
The script automatically uses the bundled `base_model/` and `vitex_14b.safetensors` — no extra downloads.
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
## Limitations
|
| 81 |
|
| 82 |
- Trained on 230 samples; coverage of artistic fonts, complex backgrounds, and non-Latin scripts is limited.
|
|
|
|
| 32 |
.
|
| 33 |
├── README.md
|
| 34 |
├── requirements.txt
|
| 35 |
+
├── inference_example.py run ViTeX-14B on one (video, mask, glyph) tuple
|
| 36 |
+
├── make_corp_baseline.py build the ViTeX-14B (Corp) variant from raw predictions
|
| 37 |
├── vitex_14b.safetensors (8 GB — trained adapter weights)
|
| 38 |
├── diffsynth/ (bundled inference library)
|
| 39 |
└── base_model/ (70 GB — frozen base model files)
|
|
|
|
| 78 |
|
| 79 |
The script automatically uses the bundled `base_model/` and `vitex_14b.safetensors` — no extra downloads.
|
| 80 |
|
| 81 |
+
## Locality-preserving variant: ViTeX-14B (Corp)
|
| 82 |
+
|
| 83 |
+
`make_corp_baseline.py` is a deterministic, training-free post-processing wrapper that composes ViTeX-14B's predicted text region back onto the source video. Two per-frame operations:
|
| 84 |
+
|
| 85 |
+
1. **Reinhard mean–variance LAB color matching** on a 20-px band just outside the mask, so the predicted glyphs match the source's local lighting.
|
| 86 |
+
2. **Signed-distance feathered alpha compositing** (4-px feather centered on the mask boundary), so the seam is smooth.
|
| 87 |
+
|
| 88 |
+
Inside the mask the result is the predicted glyphs (color-matched); outside the feather the result is byte-identical to the source. SeqAcc / CharAcc are within ~0.01 of raw ViTeX-14B (the predicted text region is unchanged), but PSNR / SSIM / LPIPS / DreamSim jump to near-Identity because the unedited region no longer pays the VAE round-trip penalty.
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
# Assumes you already have raw ViTeX-14B predictions in <pred_dir>/*.mp4
|
| 92 |
+
# and the eval split of ViTeX-Dataset under <data_root> (eval/original_videos/, eval/masks/).
|
| 93 |
+
python make_corp_baseline.py \
|
| 94 |
+
--records <data_root>/parsed_records.json \
|
| 95 |
+
--data_root <data_root> \
|
| 96 |
+
--pred_dir <raw_vitex14b_predictions_dir> \
|
| 97 |
+
--out_dir <output_dir_for_corp_baseline> \
|
| 98 |
+
--workers 8
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
CPU-only, runs in ~5 minutes on 8 workers for the 157-clip ViTeX-Bench evaluation split. Requires `ffmpeg` on `PATH`.
|
| 102 |
+
|
| 103 |
+
Reference: appendix G of the ViTeX-Bench paper.
|
| 104 |
+
|
| 105 |
## Limitations
|
| 106 |
|
| 107 |
- Trained on 230 samples; coverage of artistic fonts, complex backgrounds, and non-Latin scripts is limited.
|
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build the ViTeX-14B (Corp) baseline.
|
| 2 |
+
|
| 3 |
+
For each test clip:
|
| 4 |
+
1. Read source video, ViTeX-14B prediction, and the dilated text mask.
|
| 5 |
+
2. Color-correct the prediction inside the mask to match the source by
|
| 6 |
+
Reinhard-style mean+std matching in LAB space, using a 20-px band just
|
| 7 |
+
outside the mask as the reference (so the local lighting is captured).
|
| 8 |
+
3. Composite onto the source with a signed-distance feathered alpha
|
| 9 |
+
centered on the mask edge so the seam is smooth.
|
| 10 |
+
|
| 11 |
+
The output is a 1280x720, 24 fps, 120-frame mp4 written under
|
| 12 |
+
baseline_output_videos/ViTeX-14B_Corp/<id>.mp4.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
import subprocess
|
| 19 |
+
from multiprocessing import Pool
|
| 20 |
+
|
| 21 |
+
import cv2
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _read_video(path, max_frames=None):
|
| 26 |
+
cap = cv2.VideoCapture(path)
|
| 27 |
+
out = []
|
| 28 |
+
while True:
|
| 29 |
+
ok, f = cap.read()
|
| 30 |
+
if not ok:
|
| 31 |
+
break
|
| 32 |
+
out.append(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
|
| 33 |
+
if max_frames and len(out) >= max_frames:
|
| 34 |
+
break
|
| 35 |
+
cap.release()
|
| 36 |
+
return out
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _read_mask_video(path, target_h, target_w, max_frames=None):
|
| 40 |
+
cap = cv2.VideoCapture(path)
|
| 41 |
+
out = []
|
| 42 |
+
while True:
|
| 43 |
+
ok, f = cap.read()
|
| 44 |
+
if not ok:
|
| 45 |
+
break
|
| 46 |
+
gray = cv2.cvtColor(f, cv2.COLOR_BGR2GRAY)
|
| 47 |
+
if (gray.shape[0], gray.shape[1]) != (target_h, target_w):
|
| 48 |
+
gray = cv2.resize(gray, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
|
| 49 |
+
out.append((gray > 127).astype(np.uint8))
|
| 50 |
+
if max_frames and len(out) >= max_frames:
|
| 51 |
+
break
|
| 52 |
+
cap.release()
|
| 53 |
+
return out
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _color_correct_lab(src_rgb, pred_rgb, mask_bin, band_width=20):
|
| 57 |
+
"""Reinhard-style LAB transfer using a band around the mask as reference."""
|
| 58 |
+
band = cv2.dilate(mask_bin, np.ones((band_width * 2 + 1, band_width * 2 + 1),
|
| 59 |
+
dtype=np.uint8)) - mask_bin
|
| 60 |
+
band_idx = band > 0
|
| 61 |
+
if band_idx.sum() < 100:
|
| 62 |
+
return pred_rgb # not enough reference, leave as-is
|
| 63 |
+
|
| 64 |
+
src_lab = cv2.cvtColor(src_rgb, cv2.COLOR_RGB2LAB).astype(np.float32)
|
| 65 |
+
pred_lab = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2LAB).astype(np.float32)
|
| 66 |
+
|
| 67 |
+
mean_src = src_lab[band_idx].mean(axis=0)
|
| 68 |
+
std_src = src_lab[band_idx].std(axis=0) + 1e-6
|
| 69 |
+
mean_pred = pred_lab[band_idx].mean(axis=0)
|
| 70 |
+
std_pred = pred_lab[band_idx].std(axis=0) + 1e-6
|
| 71 |
+
|
| 72 |
+
pred_corrected = (pred_lab - mean_pred) / std_pred * std_src + mean_src
|
| 73 |
+
pred_corrected = np.clip(pred_corrected, 0, 255).astype(np.uint8)
|
| 74 |
+
return cv2.cvtColor(pred_corrected, cv2.COLOR_LAB2RGB)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _feathered_alpha(mask_bin, feather=4):
|
| 78 |
+
"""Smooth alpha centered on the mask boundary."""
|
| 79 |
+
sdf_in = cv2.distanceTransform(mask_bin, cv2.DIST_L2, 5)
|
| 80 |
+
sdf_out = cv2.distanceTransform(1 - mask_bin, cv2.DIST_L2, 5)
|
| 81 |
+
sdf = sdf_in - sdf_out
|
| 82 |
+
return np.clip((sdf + feather / 2.0) / feather, 0.0, 1.0).astype(np.float32)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _process_frame(src_rgb, pred_rgb, mask_bin, band_width, feather):
|
| 86 |
+
pred_cc = _color_correct_lab(src_rgb, pred_rgb, mask_bin, band_width=band_width)
|
| 87 |
+
alpha = _feathered_alpha(mask_bin, feather=feather)[..., None]
|
| 88 |
+
out = src_rgb.astype(np.float32) * (1 - alpha) + pred_cc.astype(np.float32) * alpha
|
| 89 |
+
return out.astype(np.uint8)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _encode_video(frames, out_path, fps=24):
|
| 93 |
+
if not frames:
|
| 94 |
+
raise RuntimeError("no frames to encode")
|
| 95 |
+
h, w = frames[0].shape[:2]
|
| 96 |
+
proc = subprocess.Popen([
|
| 97 |
+
"ffmpeg", "-y", "-loglevel", "error",
|
| 98 |
+
"-f", "rawvideo", "-pix_fmt", "rgb24",
|
| 99 |
+
"-s", f"{w}x{h}", "-r", str(fps),
|
| 100 |
+
"-i", "-",
|
| 101 |
+
"-c:v", "libx264", "-preset", "medium", "-crf", "18",
|
| 102 |
+
"-pix_fmt", "yuv420p", "-movflags", "+faststart",
|
| 103 |
+
out_path,
|
| 104 |
+
], stdin=subprocess.PIPE)
|
| 105 |
+
for f in frames:
|
| 106 |
+
proc.stdin.write(np.ascontiguousarray(f).tobytes())
|
| 107 |
+
proc.stdin.close()
|
| 108 |
+
if proc.wait() != 0:
|
| 109 |
+
raise RuntimeError(f"ffmpeg failed for {out_path}")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _process_clip(args):
|
| 113 |
+
rec, data_root, pred_dir, out_dir, target_frames, band_width, feather = args
|
| 114 |
+
vid = rec["id"]
|
| 115 |
+
out_path = os.path.join(out_dir, vid + ".mp4")
|
| 116 |
+
if os.path.exists(out_path):
|
| 117 |
+
return vid, "skip"
|
| 118 |
+
|
| 119 |
+
src_path = os.path.join(data_root, rec["original_video"])
|
| 120 |
+
mask_path = os.path.join(data_root, rec["mask_video"])
|
| 121 |
+
pred_path = os.path.join(pred_dir, vid + ".mp4")
|
| 122 |
+
if not (os.path.exists(src_path) and os.path.exists(mask_path) and os.path.exists(pred_path)):
|
| 123 |
+
return vid, "missing"
|
| 124 |
+
|
| 125 |
+
src_frames = _read_video(src_path, max_frames=target_frames)
|
| 126 |
+
pred_frames = _read_video(pred_path, max_frames=target_frames)
|
| 127 |
+
if not src_frames or not pred_frames:
|
| 128 |
+
return vid, "empty"
|
| 129 |
+
h, w = src_frames[0].shape[:2]
|
| 130 |
+
# Pred may be smaller (e.g., other res); resample to source grid.
|
| 131 |
+
pred_frames = [cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4)
|
| 132 |
+
if (f.shape[0], f.shape[1]) != (h, w) else f
|
| 133 |
+
for f in pred_frames]
|
| 134 |
+
mask_frames = _read_mask_video(mask_path, target_h=h, target_w=w, max_frames=target_frames)
|
| 135 |
+
|
| 136 |
+
n = min(len(src_frames), len(pred_frames), len(mask_frames), target_frames)
|
| 137 |
+
out_frames = []
|
| 138 |
+
for t in range(n):
|
| 139 |
+
out_frames.append(_process_frame(
|
| 140 |
+
src_frames[t], pred_frames[t], mask_frames[t], band_width, feather,
|
| 141 |
+
))
|
| 142 |
+
_encode_video(out_frames, out_path, fps=24)
|
| 143 |
+
return vid, f"ok ({n}f)"
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def main():
|
| 147 |
+
ap = argparse.ArgumentParser()
|
| 148 |
+
ap.add_argument("--records", required=True)
|
| 149 |
+
ap.add_argument("--data_root", required=True)
|
| 150 |
+
ap.add_argument("--pred_dir", required=True,
|
| 151 |
+
help="Directory of ViTeX-14B raw predictions (e.g., ViTeX-14B_orig)")
|
| 152 |
+
ap.add_argument("--out_dir", required=True,
|
| 153 |
+
help="Where the corp baseline mp4s are written")
|
| 154 |
+
ap.add_argument("--target_frames", type=int, default=120)
|
| 155 |
+
ap.add_argument("--band_width", type=int, default=20,
|
| 156 |
+
help="Width in px of the reference band around the mask")
|
| 157 |
+
ap.add_argument("--feather", type=int, default=4,
|
| 158 |
+
help="Feather width in px centered on the mask edge")
|
| 159 |
+
ap.add_argument("--workers", type=int, default=8)
|
| 160 |
+
args = ap.parse_args()
|
| 161 |
+
|
| 162 |
+
os.makedirs(args.out_dir, exist_ok=True)
|
| 163 |
+
with open(args.records) as f:
|
| 164 |
+
records = json.load(f)
|
| 165 |
+
|
| 166 |
+
tasks = [(r, args.data_root, args.pred_dir, args.out_dir,
|
| 167 |
+
args.target_frames, args.band_width, args.feather)
|
| 168 |
+
for r in records]
|
| 169 |
+
|
| 170 |
+
n_ok, n_skip, n_miss, n_err = 0, 0, 0, 0
|
| 171 |
+
with Pool(args.workers) as p:
|
| 172 |
+
for i, (vid, status) in enumerate(p.imap_unordered(_process_clip, tasks), 1):
|
| 173 |
+
if status.startswith("ok"):
|
| 174 |
+
n_ok += 1
|
| 175 |
+
elif status == "skip":
|
| 176 |
+
n_skip += 1
|
| 177 |
+
elif status == "missing":
|
| 178 |
+
n_miss += 1
|
| 179 |
+
else:
|
| 180 |
+
n_err += 1
|
| 181 |
+
if i % 10 == 0 or i == len(tasks):
|
| 182 |
+
print(f" [{i}/{len(tasks)}] {vid}: {status}", flush=True)
|
| 183 |
+
print(f"\nDone: ok={n_ok} skipped={n_skip} missing={n_miss} errors={n_err}")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
|
| 187 |
+
main()
|