opd_zt / scripts /build_coldstart_cot.py
sdzt's picture
Add files using upload-large-folder tool
bf46e5d verified
Raw
History Blame Contribute Delete
16.6 kB
"""
Build video cold-start CoT data in the OMNEX-VL ("Ground What You See") format.
Pipeline (mirrors LLaVA-CoT construction, but for video):
- Source: TemporalBench short_qa (MCQ, clean GT letter). Optional top-up: long_qa.
- Teacher: Qwen2.5-VL-72B-Instruct via vLLM, sees the real video.
- The GT answer is PROVIDED to the teacher (LLaVA-CoT style) so it writes a
plausible 4-stage trace that lands on GT. GT is used only at generation time;
it is NOT part of the stored training prompt.
- Output 4 stages with tags: <prethink> <caption> <think> <answer>.
- Rejection sampling: keep only rows where all four tags parse AND the
extracted <answer> matches GT (case-insensitive letter match).
Output (OMNEX sft.py schema):
data/coldstart_cot_10k.json -> list of {problem, data_type, path, process_and_answer, meta}
also writes incrementally to data/coldstart_cot.raw.jsonl (resume-safe).
Usage:
python scripts/build_coldstart_cot.py --limit 8 # smoke test
python scripts/build_coldstart_cot.py --target 10000 # full run
"""
from __future__ import annotations
import argparse
import json
import os
import re
import sys
import time
from pathlib import Path
import numpy as np
from PIL import Image
# Hard-coded absolute project root; the data paths below are resolved from it.
ROOT = Path("/mnt/local-fast/opd_zt")
# TemporalBench annotation files, read as temporalbench_<source>.json.
TB_RAW = ROOT / "data" / "raw" / "tb"
# TemporalBench videos, addressed by each annotation row's "video_name".
TB_VID = ROOT / "data" / "videos" / "tb"
# ---------------------------------------------------------------------------
# OMNEX-VL prompt templates (verbatim from train/.../cold-start/train/sft.py)
SYSTEM_MESSAGE = "You are a helpful assistant"
# NOTE(review): the adjacent string literals below are concatenated with no
# separator, so sentences run together (e.g. "...requirementsIn <prethink>...").
# Presumably kept exactly as in sft.py so the stored training prompt matches the
# training-time prompt byte-for-byte — confirm against sft.py before "fixing" it.
# Only placeholder: {Question}.
QUESTION_TEMPLATE = (
    "{Question}\n"
    "Please carefully analyze the pictures (or videos) and problems according to the following requirements"
    "In <prethink> </prethink> tags, carefully analyze the problem and briefly explain the steps to explain the problem and the key thinking direction of reasoning the problem"
    "In <caption> </caption> tags, Please describe the image carefully, paying special attention to the details related to the problem and the reasoning direction of solving the problem"
    "In <think> </think> tags, outline a step-by-step thought process you would use to solve the problem based on the image"
    "In <answer> </answer> tags, give the final answer in a direct format, and it must match the correct answer exactly."
    "Please sort out the output in the format of '<prethink>...</prethink>\n<caption>...</caption>\n<think>...</think>\n<answer>...</answer>' according to the above requirements"
)
# Per-question-type answer-format suffixes. build_train_user_text does not
# append these (it uses QUESTION_TEMPLATE alone, following sft.py).
TYPE_TEMPLATE = {
    "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
    "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
}
# Extra instruction shown ONLY to the teacher at generation time (stripped from
# the stored training prompt). Provides the gold answer, LLaVA-CoT style.
# Placeholder {GT} appears twice. Several of its phrases ("[supervision",
# "verified correct answer", "do not mention") are also in LEAK_PATTERNS so that
# traces echoing this hint get rejected.
GEN_GT_HINT = (
    "\n[Supervision only — do NOT mention that the answer was given] "
    "The verified correct answer is: {GT}. "
    "Write the <prethink>/<caption>/<think> so the reasoning genuinely supports this, "
    "and make <answer> exactly {GT}."
)
# One compiled regex per stage tag. The non-greedy (.*?) with DOTALL captures the
# body of a <tag>...</tag> span even across newlines; .search() takes the first one.
TAG_RE = {
    t: re.compile(rf"<{t}>(.*?)</{t}>", re.DOTALL)
    for t in ("prethink", "caption", "think", "answer")
}
def parse_stages(text: str) -> dict | None:
    """Pull the four tagged stages out of a teacher completion.

    For each tag in TAG_RE order (prethink, caption, think, answer) the first
    ``<tag>...</tag>`` span is taken and its body whitespace-stripped.

    Returns:
        ``{tag: body}`` when every tag is present and every body is non-empty;
        otherwise ``None``.
    """
    found = {}
    for tag, pattern in TAG_RE.items():
        hit = pattern.search(text)
        if hit is None:
            return None
        found[tag] = hit.group(1).strip()
    return found if all(found.values()) else None
# phrases that reveal the GT was handed to the teacher (LLaVA-CoT leakage)
# All entries are lowercase. quality_ok lowercases the space-joined stages and
# rejects the trace if any entry occurs as a plain substring — there are no word
# boundaries, so e.g. "has provided" also trips "as provided" and "was given"
# trips "as given".
LEAK_PATTERNS = [
    "correct answer is provided", "answer is given", "answer was given",
    "as given", "as provided", "provided answer", "given answer",
    "supervision", "you told me", "since the answer is", "we are told",
    "the correct answer, which is", "verified correct answer", "do not mention",
    "[supervision", "told that the answer",
]
def _max_ngram_repeat(text: str, n: int = 4) -> int:
toks = text.split()
if len(toks) < n:
return 0
from collections import Counter
grams = Counter(tuple(toks[i:i + n]) for i in range(len(toks) - n + 1))
return max(grams.values()) if grams else 0
def quality_ok(stages: dict, reason: list) -> bool:
    """Apply the minimum-quality gate to a parsed four-stage trace.

    Checks, stopping at the first failure:
      1. GT leakage: any LEAK_PATTERNS entry inside the lowercased, space-joined
         stages -> ``"leak:<pattern>"``.
      2. Substance: caption and think need >= 12 words, prethink >= 6
         -> ``"short_caption"`` / ``"short_think"`` / ``"short_prethink"``.
      3. Degeneration: a 4-gram repeated 4+ times in caption or think
         -> ``"repeat_caption"`` / ``"repeat_think"``.

    On failure exactly one tag is appended to ``reason`` (mutated in place) and
    False is returned; otherwise True.
    """
    blob = " ".join(stages.values()).lower()
    # NOTE: plain substring match, so "has provided" also hits "as provided".
    leaked = next((p for p in LEAK_PATTERNS if p in blob), None)
    if leaked is not None:
        reason.append(f"leak:{leaked}")
        return False

    word_floors = (
        ("caption", 12, "short_caption"),
        ("think", 12, "short_think"),
        ("prethink", 6, "short_prethink"),
    )
    for key, floor, tag in word_floors:
        if len(stages[key].split()) < floor:
            reason.append(tag)
            return False

    for key in ("caption", "think"):
        if _max_ngram_repeat(stages[key], 4) >= 4:
            reason.append(f"repeat_{key}")
            return False
    return True
def norm_letter(s: str) -> str:
    """Reduce an MCQ answer to a bare uppercase option letter when possible.

    A leading single letter, optionally wrapped in parentheses and followed by a
    word boundary ("B", "(b)", "C. text") yields that letter uppercased. Note that
    any leading one-letter word qualifies ("I think A" -> "I"). Anything else is
    returned stripped and uppercased ("Answer: B" -> "ANSWER: B").
    """
    stripped = s.strip()
    found = re.match(r"^\(?([A-Za-z])\)?\b", stripped)
    if found is None:
        return stripped.upper()
    return found.group(1).upper()
# ---------------------------------------------------------------------------
# data loading
def load_source_rows(source: str) -> list[dict]:
"""Load TB rows for a source ('short_qa' | 'long_qa'), only those with a video on disk."""
path = TB_RAW / f"temporalbench_{source}.json"
data = json.load(open(path))
rows = []
for r in data:
vn = r["video_name"]
vabs = TB_VID / vn
if not vabs.exists():
continue
rows.append({
"idx": r["idx"],
"video_name": vn,
"video_abs": str(vabs),
"question": r["question"],
"GT": str(r["GT"]).strip(),
"source": source,
})
return rows
def make_video_field(video_abs: str, fps: float, max_frames: int, min_frames: int,
                     nframes: int, max_pixels: int) -> dict:
    """Return the qwen_vl_utils ``{"type": "video", ...}`` content entry for a local file.

    A positive ``nframes`` selects fixed-count sampling (nframes / max_frames /
    min_frames all set to it). Otherwise (0, None, negative) the entry uses fps
    sampling bounded by ``max_frames`` / ``min_frames``.
    """
    base = {"type": "video", "video": "file://" + video_abs, "max_pixels": max_pixels}
    if not (nframes and nframes > 0):
        return {**base, "fps": fps, "max_frames": max_frames, "min_frames": min_frames}
    return {**base, "nframes": nframes, "max_frames": nframes, "min_frames": nframes}
def load_video_frames(video_abs: str, fps: float, max_frames: int, min_frames: int,
                      nframes: int, max_pixels: int):
    """Decode a video into a list of RGB PIL frames using qwen_vl_utils.

    Sampling and resizing are delegated to ``qwen_vl_utils.process_vision_info``
    with the entry built by ``make_video_field`` (fps-mode or fixed nframes).

    Returns:
        ``(frames, video_kwargs)``: ``frames`` is a list of ``PIL.Image`` (RGB), and
        ``video_kwargs`` is the extra dict process_vision_info returns with
        ``return_video_kwargs=True`` (the caller forwards it to vLLM as
        ``mm_processor_kwargs``). ``([], {})`` when no video was produced.
    """
    from qwen_vl_utils import process_vision_info

    vfield = make_video_field(video_abs, fps, max_frames, min_frames, nframes, max_pixels)
    msg = [{
        "role": "user",
        "content": [vfield, {"type": "text", "text": ""}],
    }]
    _, video_inputs, video_kwargs = process_vision_info(msg, return_video_kwargs=True)
    if not video_inputs:
        return [], {}

    vid = video_inputs[0]
    if hasattr(vid, "permute"):
        # Torch tensor, indexed per frame as (C, H, W), i.e. (T, C, H, W) overall.
        import torch
        if vid.dtype != torch.uint8:
            # Presumably float after resizing (may overshoot [0, 255]). Round before
            # the cast: float -> uint8 conversion truncates toward zero.
            if vid.is_floating_point():
                vid = vid.round()
            vid = vid.clamp(0, 255).to(torch.uint8)
        # One permute + one host copy for the whole clip instead of per frame.
        arr = vid.permute(0, 2, 3, 1).contiguous().cpu().numpy()
    else:
        arr = np.asarray(vid)
    frames = [Image.fromarray(arr[i], mode="RGB") for i in range(arr.shape[0])]
    return frames, video_kwargs
def build_train_user_text(question: str) -> str:
    """Return the stored training user prompt: *question* inserted into QUESTION_TEMPLATE.

    This mirrors OMNEX-VL sft.py exactly. Its prepare_dataset uses QUESTION_TEMPLATE
    only and never appends a TYPE_TEMPLATE suffix. TemporalBench questions already
    carry the "Answer with the option's letter..." instruction, so nothing else is
    added here.
    """
    return QUESTION_TEMPLATE.format_map({"Question": question})
# ---------------------------------------------------------------------------
def parse_args():
    """Define the command-line interface and return the parsed namespace.

    Flag groups: teacher/vLLM engine, frame sampling, batching and decoding,
    how much work to do (target / limit / sources / seed), and output paths
    under ROOT/data.
    """
    parser = argparse.ArgumentParser()
    opt = parser.add_argument

    # teacher model + vLLM engine
    opt("--teacher", default="Qwen/Qwen2.5-VL-72B-Instruct")
    opt("--tp_size", type=int, default=8)
    opt("--gpu_memory_utilization", type=float, default=0.90)
    opt("--max_model_len", type=int, default=16384)

    # frame sampling (passed through make_video_field to qwen_vl_utils)
    opt("--fps", type=float, default=2.0, help="frames per second (fps-mode)")
    opt("--max_frames", type=int, default=64, help="cap frames in fps-mode")
    opt("--min_frames", type=int, default=4)
    opt("--nframes", type=int, default=0, help=">0 forces fixed frame count instead of fps-mode")
    opt("--max_pixels", type=int, default=360 * 420)

    # batching + sampling parameters
    opt("--batch_size", type=int, default=48)
    opt("--temperature", type=float, default=0.6)
    opt("--top_p", type=float, default=0.9)
    opt("--max_tokens", type=int, default=1024)

    # how much work to do
    opt("--target", type=int, default=10000, help="stop once this many valid samples collected")
    opt("--limit", type=int, default=0, help="cap number of source rows attempted (smoke test)")
    opt("--sources", nargs="+", default=["short_qa", "long_qa"],
        help="process in order; long_qa used as top-up")
    opt("--seed", type=int, default=0)

    # outputs
    data_dir = ROOT / "data"
    opt("--out_jsonl", default=str(data_dir / "coldstart_cot.raw.jsonl"))
    opt("--out_json", default=str(data_dir / "coldstart_cot_10k.json"))

    return parser.parse_args()
def _group_rows_by_video(rows: list[dict], rng: np.random.Generator) -> tuple[list[dict], int]:
    """Shuffle the video order with ``rng`` and emit each video's rows contiguously.

    Returns ``(ordered_rows, n_videos)``. Rows of one video keep their input order.
    """
    from collections import defaultdict

    by_vid: dict[str, list[dict]] = defaultdict(list)
    for r in rows:
        by_vid[r["video_name"]].append(r)
    vids = list(by_vid.keys())
    rng.shuffle(vids)
    return [r for v in vids for r in by_vid[v]], len(vids)


def _scan_resume(out_jsonl: Path) -> tuple[set[tuple], int]:
    """Summarize an existing raw jsonl so a rerun can skip finished rows.

    Returns ``(done_keys, n_valid)``. ``done_keys`` holds ``(source, idx)`` for every
    parsable record, valid or not. ``n_valid`` counts records whose ``valid`` is truthy.
    Unparsable or malformed lines are ignored, and a missing file gives ``(set(), 0)``.
    """
    done: set[tuple] = set()
    n_valid = 0
    if not out_jsonl.exists():
        return done, n_valid
    with out_jsonl.open(encoding="utf-8") as f:
        for line in f:
            try:
                r = json.loads(line)
                meta = r["meta"]
                # idx comes from a per-source json file (load_source_rows), so it is
                # not guaranteed unique across sources: key on (source, idx).
                key = (meta.get("source"), meta["idx"])
            except (ValueError, KeyError, TypeError, AttributeError):
                continue
            done.add(key)
            if r.get("valid"):
                n_valid += 1
    return done, n_valid


def main() -> int:
    """Generate, filter and store teacher CoT traces, then compile the SFT json.

    Every attempted row is appended to ``--out_jsonl`` (valid or not, with a reason),
    so a rerun resumes by skipping rows already recorded there. Generation stops
    before the next batch once ``--target`` valid records exist, counting resumed
    ones. The training json is then rebuilt from the whole raw file by compile_json.
    """
    args = parse_args()
    rng = np.random.default_rng(args.seed)

    # --- gather source rows ---
    all_rows: list[dict] = []
    for src in args.sources:
        sr = load_source_rows(src)
        all_rows.extend(sr)
        print(f"[gen] source {src}: {len(sr)} rows with on-disk video", flush=True)

    # Emit same-video questions consecutively so the shared [system + <video tokens>]
    # prefix hits vLLM's prefix cache, reusing the expensive vision prefill instead
    # of recomputing it per question.
    rows, n_vids = _group_rows_by_video(all_rows, rng)
    if args.limit:
        rows = rows[: args.limit]
    print(f"[gen] total candidate rows: {len(rows)} across {n_vids} videos "
          f"(~{len(rows)/max(n_vids, 1):.1f} Q/video) | target valid: {args.target}", flush=True)

    # --- resume: skip rows already recorded in out_jsonl ---
    out_jsonl = Path(args.out_jsonl)
    done_keys, n_valid = _scan_resume(out_jsonl)
    if out_jsonl.exists():
        print(f"[gen] resume: {len(done_keys)} already processed, {n_valid} valid", flush=True)
    rows = [r for r in rows if (r["source"], r["idx"]) not in done_keys]
    n_valid_resumed = n_valid  # lets the progress line report this run's acceptance rate

    # --- load teacher ---
    print(f"[gen] loading vLLM teacher {args.teacher} (TP={args.tp_size})", flush=True)
    from vllm import LLM, SamplingParams
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(args.teacher, trust_remote_code=True)
    llm = LLM(
        model=args.teacher,
        tensor_parallel_size=args.tp_size,
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 0, "video": 1},
        enforce_eager=False,
        dtype="bfloat16",
        enable_prefix_caching=True,  # reuse shared [system+video] prefill across same-video Qs
    )
    sp = SamplingParams(temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens)

    # video_name -> (frames, video_kwargs). Bounded to the current batch's videos
    # (see flush_batch), so decoded frames do not accumulate for the whole run.
    frame_cache: dict[str, tuple] = {}
    t0 = time.time()
    n_attempt = 0

    def flush_batch(batch: list[dict], fout) -> None:
        """Decode, generate, filter and append one record per row of ``batch`` to ``fout``."""
        nonlocal n_valid, n_attempt
        # Rows are grouped per video and consumed in order, so a video absent from
        # this batch will not be requested again. Evict it to bound host memory.
        needed = {b["video_name"] for b in batch}
        for stale in [vn for vn in frame_cache if vn not in needed]:
            del frame_cache[stale]

        # decode frames (cached per video) + build prompts
        reqs, metas = [], []
        for b in batch:
            vn = b["video_name"]
            if vn not in frame_cache:
                try:
                    frame_cache[vn] = load_video_frames(
                        b["video_abs"], args.fps, args.max_frames, args.min_frames,
                        args.nframes, args.max_pixels)
                except Exception as e:  # best-effort: one bad video must not stop the run
                    print(f"[gen] decode failed for {vn}: {e!r}", flush=True)
                    frame_cache[vn] = ([], {})
            frames, vkw = frame_cache[vn]
            if not frames:
                fout.write(json.dumps({"valid": False, "reason": "decode", "meta": b},
                                      ensure_ascii=False) + "\n")
                continue
            # The GT hint goes to the teacher only; compile_json stores the bare question.
            gen_user = build_train_user_text(b["question"]) + GEN_GT_HINT.format(GT=b["GT"])
            msgs = [
                {"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
                {"role": "user", "content": [{"type": "video"}, {"type": "text", "text": gen_user}]},
            ]
            text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
            reqs.append({"prompt": text, "multi_modal_data": {"video": frames}, "mm_processor_kwargs": vkw})
            metas.append(b)

        if reqs:
            try:
                outs = llm.generate(reqs, sampling_params=sp, use_tqdm=False)
            except Exception as e:  # record the failure for every row and keep going
                for b in metas:
                    fout.write(json.dumps({"valid": False, "reason": f"vllm:{e}", "meta": b},
                                          ensure_ascii=False) + "\n")
            else:
                for out, b in zip(outs, metas):
                    n_attempt += 1
                    gen = out.outputs[0].text
                    stages = parse_stages(gen)
                    reason: list[str] = []
                    if stages is None:
                        ok = False
                        reason.append("no_tags")
                    elif norm_letter(stages["answer"]) != norm_letter(b["GT"]):
                        ok = False
                        reason.append("answer_mismatch")
                    else:
                        ok = quality_ok(stages, reason)
                    rec = {"valid": bool(ok), "meta": b, "gen": gen}
                    if ok:
                        rec["stages"] = stages
                        n_valid += 1
                    else:
                        rec["reason"] = ";".join(reason) or "unknown"
                    fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
        fout.flush()

    # UTF-8 explicitly: records are dumped with ensure_ascii=False.
    with out_jsonl.open("a", encoding="utf-8") as fout:
        for i in range(0, len(rows), args.batch_size):
            if n_valid >= args.target:
                print(f"[gen] reached target {args.target}, stopping.", flush=True)
                break
            flush_batch(rows[i: i + args.batch_size], fout)
            session_valid = n_valid - n_valid_resumed
            rate = session_valid / max(1, n_attempt)
            el = time.time() - t0
            print(f"[gen] attempted={n_attempt} valid={n_valid} "
                  f"(this run {session_valid}/{n_attempt} = {rate:.1%}) "
                  f"elapsed={el/60:.1f}m {n_attempt/max(el, 1):.2f} it/s", flush=True)

    compile_json(out_jsonl, Path(args.out_json), args.target)
    return 0
def compile_json(raw_jsonl: Path, out_json: Path, target: int):
    """Turn valid raw records into the OMNEX sft.py training schema.

    Reads ``raw_jsonl`` line by line and skips:
      - lines that are not valid JSON;
      - records without a truthy ``valid``;
      - records lacking dict ``meta`` / ``stages``;
      - repeats of an already-kept ``(source, idx)``.
    Keeps at most ``target`` samples; ``target <= 0`` yields an empty list. The
    list is written to ``out_json`` as UTF-8 JSON.

    Each sample has:
      - ``problem``: the bare question, without the GT hint;
      - ``data_type``: ``"video"``;
      - ``path``: the absolute video path;
      - ``process_and_answer``: the four tagged stages joined by newlines;
      - ``meta``: ``{idx, source, GT}``.
    """
    samples: list[dict] = []
    seen: set[tuple] = set()
    with open(raw_jsonl, encoding="utf-8") as fin:
        for line in fin:
            if len(samples) >= target:
                break
            try:
                r = json.loads(line)
            except ValueError:  # not a JSON line
                continue
            if not isinstance(r, dict) or not r.get("valid"):
                continue
            b, s = r.get("meta"), r.get("stages")
            if not isinstance(b, dict) or not isinstance(s, dict):
                continue
            key = (b.get("source"), b.get("idx"))
            if key in seen:  # one training sample per source row
                continue
            seen.add(key)
            process_and_answer = (
                f"<prethink>{s['prethink']}</prethink>\n"
                f"<caption>{s['caption']}</caption>\n"
                f"<think>{s['think']}</think>\n"
                f"<answer>{s['answer']}</answer>"
            )
            samples.append({
                "problem": b["question"],
                "data_type": "video",
                "path": b["video_abs"],
                "process_and_answer": process_and_answer,
                "meta": {"idx": b["idx"], "source": b["source"], "GT": b["GT"]},
            })
    # UTF-8 explicitly: ensure_ascii=False may emit non-ASCII characters.
    with open(out_json, "w", encoding="utf-8") as fout:
        json.dump(samples, fout, ensure_ascii=False, indent=1)
    print(f"[gen] wrote {len(samples)} samples -> {out_json}", flush=True)
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())