ViTeX-Edit-14B / make_corp_baseline.py
Anonymous Authors
Rename displayed model name to ViTeX-Edit-14B in the model card
9b9fe26
"""Build the ViTeX-Edit-14B (Composite) baseline.
For each test clip:
1. Read source video, ViTeX-Edit-14B prediction, and the dilated text mask.
2. Color-correct the prediction inside the mask to match the source by
Reinhard-style mean+std matching in LAB space, using a 20-px band just
outside the mask as the reference (so the local lighting is captured).
3. Composite onto the source with a signed-distance feathered alpha
centered on the mask edge so the seam is smooth.
The output is a 1280x720, 24 fps, 120-frame mp4 written under
baseline_output_videos/ViTeX-14B_Corp/<id>.mp4.
"""
import argparse
import json
import os
import subprocess
from multiprocessing import Pool
import cv2
import numpy as np
def _read_video(path, max_frames=None):
    """Decode a video file into a list of RGB frames.

    Args:
        path: Path of the video to open with OpenCV.
        max_frames: Optional cap on the number of frames returned. A falsy
            value (``None`` or ``0``) means "read until decoding stops".

    Returns:
        A list of frames converted from OpenCV's BGR order to RGB. Reading
        stops at the first unsuccessful ``cap.read()``, so the list may be
        empty.
    """
    frames = []
    cap = cv2.VideoCapture(path)
    try:
        # Check the limit before each read so at most max_frames are decoded.
        while not (max_frames and len(frames) >= max_frames):
            ok, bgr = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even if a colour conversion raises.
        cap.release()
    return frames
def _read_mask_video(path, target_h, target_w, max_frames=None):
    """Decode a mask video into a list of binary ``uint8`` masks.

    Args:
        path: Path of the mask video to open with OpenCV.
        target_h: Height every returned mask must have.
        target_w: Width every returned mask must have.
        max_frames: Optional cap on the number of masks returned. A falsy
            value (``None`` or ``0``) means "read until decoding stops".

    Returns:
        A list of ``(target_h, target_w)`` arrays holding 1 where the
        grayscale frame is brighter than 127 and 0 elsewhere. Reading stops
        at the first unsuccessful ``cap.read()``, so the list may be empty.
    """
    masks = []
    cap = cv2.VideoCapture(path)
    try:
        while not (max_frames and len(masks) >= max_frames):
            ok, frame = cap.read()
            if not ok:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if gray.shape[:2] != (target_h, target_w):
                # Nearest-neighbour keeps the mask free of interpolated
                # in-between gray levels.
                gray = cv2.resize(gray, (target_w, target_h),
                                  interpolation=cv2.INTER_NEAREST)
            masks.append((gray > 127).astype(np.uint8))
    finally:
        # Release the capture even if a conversion or resize raises.
        cap.release()
    return masks
def _color_correct_lab(src_rgb, pred_rgb, mask_bin, band_width=20):
    """Reinhard-style LAB transfer using a band around the mask as reference.

    Per-channel mean and standard deviation are measured in LAB space over a
    band of pixels just outside the mask, once on the source and once on the
    prediction, and the whole prediction is remapped so its band statistics
    match the source's.

    Args:
        src_rgb: Source frame in RGB (assumed uint8 -- the final 0..255 clip
            relies on OpenCV's 8-bit LAB encoding; confirm against callers).
        pred_rgb: Predicted frame in RGB with the same shape as ``src_rgb``.
        mask_bin: Binary ``uint8`` mask (1 inside the edited region).
        band_width: Half-size in px of the dilation used to build the band.

    Returns:
        The colour-corrected prediction in RGB, or ``pred_rgb`` itself when
        the band holds fewer than 100 pixels.
    """
    kernel_size = band_width * 2 + 1
    kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    # Dilation only ever adds pixels, so the uint8 subtraction cannot wrap.
    band = cv2.dilate(mask_bin, kernel) - mask_bin
    band_idx = band > 0
    if band_idx.sum() < 100:
        return pred_rgb  # not enough reference, leave as-is

    src_lab = cv2.cvtColor(src_rgb, cv2.COLOR_RGB2LAB).astype(np.float32)
    pred_lab = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2LAB).astype(np.float32)

    src_band = src_lab[band_idx]
    pred_band = pred_lab[band_idx]
    mean_src = src_band.mean(axis=0)
    std_src = src_band.std(axis=0) + 1e-6
    mean_pred = pred_band.mean(axis=0)
    raw_std_pred = pred_band.std(axis=0)

    # A (near-)constant predicted band carries no contrast information.
    # Dividing by ~1e-6 would amplify every in-mask deviation by ~1e6 and
    # saturate the result, so those channels get a mean shift only.
    scale = np.where(raw_std_pred < 1e-3, 1.0,
                     std_src / (raw_std_pred + 1e-6)).astype(np.float32)

    pred_corrected = (pred_lab - mean_pred) * scale + mean_src
    # Round before the cast; a bare astype() truncates and biases darker.
    pred_corrected = np.clip(np.rint(pred_corrected), 0, 255).astype(np.uint8)
    return cv2.cvtColor(pred_corrected, cv2.COLOR_LAB2RGB)
def _feathered_alpha(mask_bin, feather=4):
    """Smooth alpha centered on the mask boundary.

    A signed distance (positive inside the mask, negative outside) is mapped
    linearly so that it reaches 0 at ``-feather / 2`` and 1 at
    ``+feather / 2``.

    Args:
        mask_bin: Binary ``uint8`` mask (1 inside the edited region).
        feather: Width in px of the linear ramp. A non-positive value
            disables feathering and yields a hard 0/1 alpha.

    Returns:
        A ``float32`` array with the shape of ``mask_bin`` and values in
        ``[0, 1]``.
    """
    if feather <= 0:
        # No ramp to build: avoid dividing by zero (feather == 0) or
        # inverting the alpha (feather < 0).
        return (mask_bin > 0).astype(np.float32)
    dist_inside = cv2.distanceTransform(mask_bin, cv2.DIST_L2, 5)
    dist_outside = cv2.distanceTransform(1 - mask_bin, cv2.DIST_L2, 5)
    signed_dist = dist_inside - dist_outside
    ramp = (signed_dist + feather / 2.0) / feather
    return np.clip(ramp, 0.0, 1.0).astype(np.float32)
def _process_frame(src_rgb, pred_rgb, mask_bin, band_width, feather):
    """Colour-correct one predicted frame and composite it onto the source.

    Args:
        src_rgb: Source frame in RGB.
        pred_rgb: Predicted frame in RGB, same shape as ``src_rgb``.
        mask_bin: Binary ``uint8`` mask (1 inside the edited region).
        band_width: Reference band width passed to ``_color_correct_lab``.
        feather: Feather width passed to ``_feathered_alpha``.

    Returns:
        The composited ``uint8`` frame: the source where alpha is 0, the
        colour-corrected prediction where alpha is 1, a linear blend between.
    """
    pred_cc = _color_correct_lab(src_rgb, pred_rgb, mask_bin, band_width=band_width)
    # Trailing axis lets the single-channel alpha broadcast over RGB.
    alpha = _feathered_alpha(mask_bin, feather=feather)[..., None]
    blended = (src_rgb.astype(np.float32) * (1 - alpha)
               + pred_cc.astype(np.float32) * alpha)
    # Round instead of truncating: float error can leave a blend of two equal
    # pixels at v - epsilon, which a bare astype() would turn into v - 1.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
def _encode_video(frames, out_path, fps=24):
    """Encode RGB frames to an H.264 video by piping raw video into ffmpeg.

    The video is first written next to ``out_path`` under a temporary name
    and only renamed onto ``out_path`` once ffmpeg has exited successfully,
    so a failed or interrupted encode never leaves a partial file at
    ``out_path``.

    Args:
        frames: Non-empty sequence of ``(H, W, 3)`` RGB arrays, all the same
            size (assumed uint8, matching the ``rgb24`` input format).
        out_path: Destination file; ffmpeg infers the container from its
            extension.
        fps: Frame rate declared for the raw input stream.

    Raises:
        RuntimeError: If ``frames`` is empty or ffmpeg fails.
    """
    if not frames:
        raise RuntimeError("no frames to encode")
    h, w = frames[0].shape[:2]

    # Keep the original extension last so ffmpeg still picks the same muxer.
    root, ext = os.path.splitext(out_path)
    tmp_path = f"{root}.part{ext}"
    cmd = [
        "ffmpeg", "-y", "-loglevel", "error",
        "-f", "rawvideo", "-pix_fmt", "rgb24",
        "-s", f"{w}x{h}", "-r", str(fps),
        "-i", "-",
        "-c:v", "libx264", "-preset", "medium", "-crf", "18",
        "-pix_fmt", "yuv420p", "-movflags", "+faststart",
        tmp_path,
    ]

    pipe_error = None
    try:
        # Leaving the ``with`` block closes stdin and waits for ffmpeg.
        with subprocess.Popen(cmd, stdin=subprocess.PIPE) as proc:
            for frame in frames:
                proc.stdin.write(np.ascontiguousarray(frame).tobytes())
    except BrokenPipeError as err:
        # ffmpeg exited before consuming every frame; report it as a failure.
        pipe_error = err

    if pipe_error is not None or proc.returncode != 0:
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise RuntimeError(f"ffmpeg failed for {out_path}") from pipe_error

    os.replace(tmp_path, out_path)
def _process_clip(args):
    """Build the composited baseline video for one test clip.

    Args:
        args: Tuple ``(rec, data_root, pred_dir, out_dir, target_frames,
            band_width, feather)``. ``rec`` is a record dict providing the
            keys ``"id"``, ``"original_video"`` and ``"mask_video"``; the
            two video entries are paths relative to ``data_root``.

    Returns:
        ``(clip_id, status)`` where status is ``"skip"`` (output already
        exists), ``"missing"`` (an input file is absent), ``"empty"`` (no
        usable frames) or ``"ok (<n>f)"`` after writing ``n`` frames.
    """
    rec, data_root, pred_dir, out_dir, target_frames, band_width, feather = args
    vid = rec["id"]
    out_path = os.path.join(out_dir, vid + ".mp4")
    if os.path.exists(out_path):
        return vid, "skip"

    src_path = os.path.join(data_root, rec["original_video"])
    mask_path = os.path.join(data_root, rec["mask_video"])
    pred_path = os.path.join(pred_dir, vid + ".mp4")
    if not all(os.path.exists(p) for p in (src_path, mask_path, pred_path)):
        return vid, "missing"

    src_frames = _read_video(src_path, max_frames=target_frames)
    pred_frames = _read_video(pred_path, max_frames=target_frames)
    if not src_frames or not pred_frames:
        return vid, "empty"

    h, w = src_frames[0].shape[:2]
    # Pred may be smaller (e.g., other res); resample to source grid.
    pred_frames = [
        f if f.shape[:2] == (h, w)
        else cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4)
        for f in pred_frames
    ]
    mask_frames = _read_mask_video(mask_path, target_h=h, target_w=w,
                                   max_frames=target_frames)

    n = min(len(src_frames), len(pred_frames), len(mask_frames), target_frames)
    if n <= 0:
        # An unreadable mask video (or a non-positive target_frames) leaves
        # nothing to encode; report it instead of letting the encoder raise
        # and abort the whole worker pool.
        return vid, "empty"

    out_frames = [
        _process_frame(src, pred, mask, band_width, feather)
        for src, pred, mask in zip(src_frames[:n], pred_frames[:n], mask_frames[:n])
    ]
    _encode_video(out_frames, out_path, fps=24)
    return vid, f"ok ({n}f)"
def main():
    """Parse CLI arguments and build the baseline for every record in parallel.

    Reads the JSON list of clip records, dispatches one ``_process_clip``
    task per record to a process pool, prints progress, and finishes with a
    summary of ok / skipped / missing / error counts.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--records", required=True)
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--pred_dir", required=True,
                    help="Directory of ViTeX-Edit-14B raw predictions (e.g., ViTeX-Edit-14B_orig)")
    ap.add_argument("--out_dir", required=True,
                    help="Where the corp baseline mp4s are written")
    ap.add_argument("--target_frames", type=int, default=120)
    ap.add_argument("--band_width", type=int, default=20,
                    help="Width in px of the reference band around the mask")
    ap.add_argument("--feather", type=int, default=4,
                    help="Feather width in px centered on the mask edge")
    ap.add_argument("--workers", type=int, default=8)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(args.records, encoding="utf-8") as f:
        records = json.load(f)

    tasks = [(r, args.data_root, args.pred_dir, args.out_dir,
              args.target_frames, args.band_width, args.feather)
             for r in records]
    total = len(tasks)

    n_ok, n_skip, n_miss, n_err = 0, 0, 0, 0
    # Pool() rejects a worker count below 1, so clamp a stray "--workers 0".
    with Pool(max(1, args.workers)) as p:
        for i, (vid, status) in enumerate(p.imap_unordered(_process_clip, tasks), 1):
            failed = False
            if status.startswith("ok"):
                n_ok += 1
            elif status == "skip":
                n_skip += 1
            elif status == "missing":
                n_miss += 1
                failed = True
            else:
                n_err += 1
                failed = True
            # Always surface failing clips; otherwise print every 10th line.
            if failed or i % 10 == 0 or i == total:
                print(f" [{i}/{total}] {vid}: {status}", flush=True)
    print(f"\nDone: ok={n_ok} skipped={n_skip} missing={n_miss} errors={n_err}")
# Run the CLI only when executed as a script, not when imported (the
# multiprocessing workers may re-import this module).
if __name__ == "__main__":
    main()