ViTeX-Bench commited on
Commit
78af9f2
·
1 Parent(s): 4932bf3

Add ViTeX-14B (Corp): training-free locality-preserving post-processing wrapper

Browse files

`make_corp_baseline.py` composes a raw ViTeX-14B prediction's text region back
onto the source video with two per-frame operations:

1. Reinhard mean-variance LAB color matching on a 20-px band just outside
the mask, so the predicted glyphs match the source's local lighting.
2. Signed-distance feathered alpha compositing (4-px feather centered on
the mask boundary), so the boundary has no visible seam.

Inside the mask the result is the predicted glyphs (color-matched); outside
the feather the result is byte-identical to the source. SeqAcc / CharAcc /
TTS are within ~0.01 of raw ViTeX-14B (the predicted text region is
unchanged), while PSNR / SSIM / LPIPS / DreamSim jump to near-Identity.

CPU-only, ~5 min on 8 workers for the 157-clip evaluation split. Reference:
appendix G of the ViTeX-Bench paper.

Files changed (2) hide show
  1. README.md +26 -1
  2. make_corp_baseline.py +187 -0
README.md CHANGED
@@ -32,7 +32,8 @@ This repository is fully self-contained — it bundles the trained weights, the
32
  .
33
  ├── README.md
34
  ├── requirements.txt
35
- ├── inference_example.py
 
36
  ├── vitex_14b.safetensors (8 GB — trained adapter weights)
37
  ├── diffsynth/ (bundled inference library)
38
  └── base_model/ (70 GB — frozen base model files)
@@ -77,6 +78,30 @@ python inference_example.py \
77
 
78
  The script automatically uses the bundled `base_model/` and `vitex_14b.safetensors` — no extra downloads.
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  ## Limitations
81
 
82
  - Trained on 230 samples; coverage of artistic fonts, complex backgrounds, and non-Latin scripts is limited.
 
32
  .
33
  ├── README.md
34
  ├── requirements.txt
35
+ ├── inference_example.py run ViTeX-14B on one (video, mask, glyph) tuple
36
+ ├── make_corp_baseline.py build the ViTeX-14B (Corp) variant from raw predictions
37
  ├── vitex_14b.safetensors (8 GB — trained adapter weights)
38
  ├── diffsynth/ (bundled inference library)
39
  └── base_model/ (70 GB — frozen base model files)
 
78
 
79
  The script automatically uses the bundled `base_model/` and `vitex_14b.safetensors` — no extra downloads.
80
 
81
+ ## Locality-preserving variant: ViTeX-14B (Corp)
82
+
83
+ `make_corp_baseline.py` is a deterministic, training-free post-processing wrapper that composes ViTeX-14B's predicted text region back onto the source video. Two per-frame operations:
84
+
85
+ 1. **Reinhard mean–variance LAB color matching** on a 20-px band just outside the mask, so the predicted glyphs match the source's local lighting.
86
+ 2. **Signed-distance feathered alpha compositing** (4-px feather centered on the mask boundary), so the seam is smooth.
87
+
88
+ Inside the mask the result is the predicted glyphs (color-matched); outside the feather the result is byte-identical to the source. SeqAcc / CharAcc are within ~0.01 of raw ViTeX-14B (the predicted text region is unchanged), but PSNR / SSIM / LPIPS / DreamSim jump to near-Identity because the unedited region no longer pays the VAE round-trip penalty.
89
+
90
+ ```bash
91
+ # Assumes you already have raw ViTeX-14B predictions in <pred_dir>/*.mp4
92
+ # and the eval split of ViTeX-Dataset under <data_root> (eval/original_videos/, eval/masks/).
93
+ python make_corp_baseline.py \
94
+ --records <data_root>/parsed_records.json \
95
+ --data_root <data_root> \
96
+ --pred_dir <raw_vitex14b_predictions_dir> \
97
+ --out_dir <output_dir_for_corp_baseline> \
98
+ --workers 8
99
+ ```
100
+
101
+ CPU-only, runs in ~5 minutes on 8 workers for the 157-clip ViTeX-Bench evaluation split. Requires `ffmpeg` on `PATH`.
102
+
103
+ Reference: appendix G of the ViTeX-Bench paper.
104
+
105
  ## Limitations
106
 
107
  - Trained on 230 samples; coverage of artistic fonts, complex backgrounds, and non-Latin scripts is limited.
make_corp_baseline.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build the ViTeX-14B (Corp) baseline.
2
+
3
+ For each test clip:
4
+ 1. Read source video, ViTeX-14B prediction, and the dilated text mask.
5
+ 2. Color-correct the prediction inside the mask to match the source by
6
+ Reinhard-style mean+std matching in LAB space, using a 20-px band just
7
+ outside the mask as the reference (so the local lighting is captured).
8
+ 3. Composite onto the source with a signed-distance feathered alpha
9
+ centered on the mask edge so the seam is smooth.
10
+
11
+ The output is a 1280x720, 24 fps, 120-frame mp4 written under
12
+ baseline_output_videos/ViTeX-14B_Corp/<id>.mp4.
13
+ """
14
+
15
+ import argparse
16
+ import json
17
+ import os
18
+ import subprocess
19
+ from multiprocessing import Pool
20
+
21
+ import cv2
22
+ import numpy as np
23
+
24
+
25
+ def _read_video(path, max_frames=None):
26
+ cap = cv2.VideoCapture(path)
27
+ out = []
28
+ while True:
29
+ ok, f = cap.read()
30
+ if not ok:
31
+ break
32
+ out.append(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
33
+ if max_frames and len(out) >= max_frames:
34
+ break
35
+ cap.release()
36
+ return out
37
+
38
+
39
+ def _read_mask_video(path, target_h, target_w, max_frames=None):
40
+ cap = cv2.VideoCapture(path)
41
+ out = []
42
+ while True:
43
+ ok, f = cap.read()
44
+ if not ok:
45
+ break
46
+ gray = cv2.cvtColor(f, cv2.COLOR_BGR2GRAY)
47
+ if (gray.shape[0], gray.shape[1]) != (target_h, target_w):
48
+ gray = cv2.resize(gray, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
49
+ out.append((gray > 127).astype(np.uint8))
50
+ if max_frames and len(out) >= max_frames:
51
+ break
52
+ cap.release()
53
+ return out
54
+
55
+
56
+ def _color_correct_lab(src_rgb, pred_rgb, mask_bin, band_width=20):
57
+ """Reinhard-style LAB transfer using a band around the mask as reference."""
58
+ band = cv2.dilate(mask_bin, np.ones((band_width * 2 + 1, band_width * 2 + 1),
59
+ dtype=np.uint8)) - mask_bin
60
+ band_idx = band > 0
61
+ if band_idx.sum() < 100:
62
+ return pred_rgb # not enough reference, leave as-is
63
+
64
+ src_lab = cv2.cvtColor(src_rgb, cv2.COLOR_RGB2LAB).astype(np.float32)
65
+ pred_lab = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2LAB).astype(np.float32)
66
+
67
+ mean_src = src_lab[band_idx].mean(axis=0)
68
+ std_src = src_lab[band_idx].std(axis=0) + 1e-6
69
+ mean_pred = pred_lab[band_idx].mean(axis=0)
70
+ std_pred = pred_lab[band_idx].std(axis=0) + 1e-6
71
+
72
+ pred_corrected = (pred_lab - mean_pred) / std_pred * std_src + mean_src
73
+ pred_corrected = np.clip(pred_corrected, 0, 255).astype(np.uint8)
74
+ return cv2.cvtColor(pred_corrected, cv2.COLOR_LAB2RGB)
75
+
76
+
77
+ def _feathered_alpha(mask_bin, feather=4):
78
+ """Smooth alpha centered on the mask boundary."""
79
+ sdf_in = cv2.distanceTransform(mask_bin, cv2.DIST_L2, 5)
80
+ sdf_out = cv2.distanceTransform(1 - mask_bin, cv2.DIST_L2, 5)
81
+ sdf = sdf_in - sdf_out
82
+ return np.clip((sdf + feather / 2.0) / feather, 0.0, 1.0).astype(np.float32)
83
+
84
+
85
+ def _process_frame(src_rgb, pred_rgb, mask_bin, band_width, feather):
86
+ pred_cc = _color_correct_lab(src_rgb, pred_rgb, mask_bin, band_width=band_width)
87
+ alpha = _feathered_alpha(mask_bin, feather=feather)[..., None]
88
+ out = src_rgb.astype(np.float32) * (1 - alpha) + pred_cc.astype(np.float32) * alpha
89
+ return out.astype(np.uint8)
90
+
91
+
92
+ def _encode_video(frames, out_path, fps=24):
93
+ if not frames:
94
+ raise RuntimeError("no frames to encode")
95
+ h, w = frames[0].shape[:2]
96
+ proc = subprocess.Popen([
97
+ "ffmpeg", "-y", "-loglevel", "error",
98
+ "-f", "rawvideo", "-pix_fmt", "rgb24",
99
+ "-s", f"{w}x{h}", "-r", str(fps),
100
+ "-i", "-",
101
+ "-c:v", "libx264", "-preset", "medium", "-crf", "18",
102
+ "-pix_fmt", "yuv420p", "-movflags", "+faststart",
103
+ out_path,
104
+ ], stdin=subprocess.PIPE)
105
+ for f in frames:
106
+ proc.stdin.write(np.ascontiguousarray(f).tobytes())
107
+ proc.stdin.close()
108
+ if proc.wait() != 0:
109
+ raise RuntimeError(f"ffmpeg failed for {out_path}")
110
+
111
+
112
+ def _process_clip(args):
113
+ rec, data_root, pred_dir, out_dir, target_frames, band_width, feather = args
114
+ vid = rec["id"]
115
+ out_path = os.path.join(out_dir, vid + ".mp4")
116
+ if os.path.exists(out_path):
117
+ return vid, "skip"
118
+
119
+ src_path = os.path.join(data_root, rec["original_video"])
120
+ mask_path = os.path.join(data_root, rec["mask_video"])
121
+ pred_path = os.path.join(pred_dir, vid + ".mp4")
122
+ if not (os.path.exists(src_path) and os.path.exists(mask_path) and os.path.exists(pred_path)):
123
+ return vid, "missing"
124
+
125
+ src_frames = _read_video(src_path, max_frames=target_frames)
126
+ pred_frames = _read_video(pred_path, max_frames=target_frames)
127
+ if not src_frames or not pred_frames:
128
+ return vid, "empty"
129
+ h, w = src_frames[0].shape[:2]
130
+ # Pred may be smaller (e.g., other res); resample to source grid.
131
+ pred_frames = [cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4)
132
+ if (f.shape[0], f.shape[1]) != (h, w) else f
133
+ for f in pred_frames]
134
+ mask_frames = _read_mask_video(mask_path, target_h=h, target_w=w, max_frames=target_frames)
135
+
136
+ n = min(len(src_frames), len(pred_frames), len(mask_frames), target_frames)
137
+ out_frames = []
138
+ for t in range(n):
139
+ out_frames.append(_process_frame(
140
+ src_frames[t], pred_frames[t], mask_frames[t], band_width, feather,
141
+ ))
142
+ _encode_video(out_frames, out_path, fps=24)
143
+ return vid, f"ok ({n}f)"
144
+
145
+
146
+ def main():
147
+ ap = argparse.ArgumentParser()
148
+ ap.add_argument("--records", required=True)
149
+ ap.add_argument("--data_root", required=True)
150
+ ap.add_argument("--pred_dir", required=True,
151
+ help="Directory of ViTeX-14B raw predictions (e.g., ViTeX-14B_orig)")
152
+ ap.add_argument("--out_dir", required=True,
153
+ help="Where the corp baseline mp4s are written")
154
+ ap.add_argument("--target_frames", type=int, default=120)
155
+ ap.add_argument("--band_width", type=int, default=20,
156
+ help="Width in px of the reference band around the mask")
157
+ ap.add_argument("--feather", type=int, default=4,
158
+ help="Feather width in px centered on the mask edge")
159
+ ap.add_argument("--workers", type=int, default=8)
160
+ args = ap.parse_args()
161
+
162
+ os.makedirs(args.out_dir, exist_ok=True)
163
+ with open(args.records) as f:
164
+ records = json.load(f)
165
+
166
+ tasks = [(r, args.data_root, args.pred_dir, args.out_dir,
167
+ args.target_frames, args.band_width, args.feather)
168
+ for r in records]
169
+
170
+ n_ok, n_skip, n_miss, n_err = 0, 0, 0, 0
171
+ with Pool(args.workers) as p:
172
+ for i, (vid, status) in enumerate(p.imap_unordered(_process_clip, tasks), 1):
173
+ if status.startswith("ok"):
174
+ n_ok += 1
175
+ elif status == "skip":
176
+ n_skip += 1
177
+ elif status == "missing":
178
+ n_miss += 1
179
+ else:
180
+ n_err += 1
181
+ if i % 10 == 0 or i == len(tasks):
182
+ print(f" [{i}/{len(tasks)}] {vid}: {status}", flush=True)
183
+ print(f"\nDone: ok={n_ok} skipped={n_skip} missing={n_miss} errors={n_err}")
184
+
185
+
186
+ if __name__ == "__main__":
187
+ main()