| """ |
| Build video cold-start CoT data in the OMNEX-VL ("Ground What You See") format. |
| |
| Pipeline (mirrors LLaVA-CoT construction, but for video): |
| - Source: TemporalBench short_qa (MCQ, clean GT letter). Optional top-up: long_qa. |
| - Teacher: Qwen2.5-VL-72B-Instruct via vLLM, sees the real video. |
| - The GT answer is PROVIDED to the teacher (LLaVA-CoT style) so it writes a |
| plausible 4-stage trace that lands on GT. GT is used only at generation time; |
| it is NOT part of the stored training prompt. |
| - Output 4 stages with tags: <prethink> <caption> <think> <answer>. |
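- Rejection sampling: keep only rows where all four tags parse, the
  extracted <answer> matches GT (case-insensitive letter match), AND the
  trace passes quality_ok (no leakage of the GT hint, minimum stage
  lengths, no degenerate n-gram repetition).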
| |
| Output (OMNEX sft.py schema): |
| data/coldstart_cot_10k.json -> list of {problem, data_type, path, process_and_answer, meta} |
| also writes incrementally to data/coldstart_cot.raw.jsonl (resume-safe). |
| |
| Usage: |
| python scripts/build_coldstart_cot.py --limit 8 # smoke test |
| python scripts/build_coldstart_cot.py --target 10000 # full run |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import re |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| from PIL import Image |
|
|
# Project root on the local fast disk; the TemporalBench inputs and the default
# output paths (see parse_args) all live underneath it.
ROOT = Path("/mnt/local-fast/opd_zt")
# TemporalBench annotation JSONs and the matching downloaded videos.
TB_RAW = ROOT.joinpath("data", "raw", "tb")
TB_VID = ROOT.joinpath("data", "videos", "tb")
|
|
| |
| |
|
|
# System turn sent with every teacher request.
SYSTEM_MESSAGE = "You are a helpful assistant"

# OMNEX-VL four-stage instruction; "{Question}" is filled with str.format.
# NOTE(review): the pieces below are joined with no separator, so the rendered
# prompt reads "...following requirementsIn <prethink> ...". It is kept
# byte-for-byte, presumably to mirror the upstream sft.py template -- confirm
# before "fixing" it.
QUESTION_TEMPLATE = "".join([
    "{Question}\n",
    "Please carefully analyze the pictures (or videos) and problems according to the following requirements",
    "In <prethink> </prethink> tags, carefully analyze the problem and briefly explain the steps to explain the problem and the key thinking direction of reasoning the problem",
    "In <caption> </caption> tags, Please describe the image carefully, paying special attention to the details related to the problem and the reasoning direction of solving the problem",
    "In <think> </think> tags, outline a step-by-step thought process you would use to solve the problem based on the image",
    "In <answer> </answer> tags, give the final answer in a direct format, and it must match the correct answer exactly.",
    "Please sort out the output in the format of '<prethink>...</prethink>\n<caption>...</caption>\n<think>...</think>\n<answer>...</answer>' according to the above requirements",
])

# Per-answer-type suffixes, keyed by problem type. Not referenced elsewhere in
# this script: build_train_user_text applies QUESTION_TEMPLATE only.
TYPE_TEMPLATE = {
    "multiple choice": (
        " Please provide only the single option letter (e.g., A, B, C, D, etc.)"
        " within the <answer> </answer> tags."
    ),
    "numerical": (
        " Please provide the numerical value (e.g., 42 or 3.14)"
        " within the <answer> </answer> tags."
    ),
}
|
|
| |
| |
# Appended to the teacher prompt only; compile_json stores the bare question as
# the training "problem", so this hint never reaches the training data. It
# reveals the ground-truth option ("{GT}", filled with str.format) so the
# teacher writes a trace that lands on it. Its wording overlaps LEAK_PATTERNS,
# so a trace that echoes the hint is rejected.
GEN_GT_HINT = "".join([
    "\n[Supervision only — do NOT mention that the answer was given] ",
    "The verified correct answer is: {GT}. ",
    "Write the <prethink>/<caption>/<think> so the reasoning genuinely supports this, ",
    "and make <answer> exactly {GT}.",
])
|
|
# One compiled pattern per output stage. Each captures the body of the first
# <tag>...</tag> pair: the match is non-greedy, and DOTALL lets it span newlines.
# parse_stages walks this dict in insertion order (prethink, caption, think, answer).
TAG_RE = {
    name: re.compile(r"<%s>(.*?)</%s>" % (name, name), re.DOTALL)
    for name in ("prethink", "caption", "think", "answer")
}
|
|
|
|
def parse_stages(text: str) -> dict | None:
    """Extract the four tagged stages from a teacher generation.

    For each tag in TAG_RE (prethink, caption, think, answer), the body of its
    first occurrence is taken and whitespace-stripped.

    Returns:
        A dict keyed prethink/caption/think/answer, or None if any tag is
        missing or any stage is empty after stripping.
    """
    stages = {}
    for tag, pattern in TAG_RE.items():
        hit = pattern.search(text)
        if hit is None:
            return None
        stages[tag] = hit.group(1).strip()
    # Presence alone is not enough: a tag with only whitespace inside also fails.
    return stages if all(stages.values()) else None
|
|
|
|
| |
# Lower-case phrases suggesting that the teacher referred to the supervision hint
# (GEN_GT_HINT) instead of reasoning from the video. quality_ok checks them
# against the lower-cased trace and rejects it on the first hit. The hit is
# recorded as "leak:<pattern>", so list order only affects which reason is logged.
LEAK_PATTERNS = [
    # "the answer was provided/given" phrasing
    "correct answer is provided",
    "answer is given",
    "answer was given",
    "as given",
    "as provided",
    "provided answer",
    "given answer",
    # meta references to being supervised or told
    "supervision",
    "you told me",
    "since the answer is",
    "we are told",
    # echoes of the hint's own wording
    "the correct answer, which is",
    "verified correct answer",
    "do not mention",
    "[supervision",
    "told that the answer",
]
|
|
|
|
| def _max_ngram_repeat(text: str, n: int = 4) -> int: |
| toks = text.split() |
| if len(toks) < n: |
| return 0 |
| from collections import Counter |
| grams = Counter(tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)) |
| return max(grams.values()) if grams else 0 |
|
|
|
|
def quality_ok(stages: dict, reason: list) -> bool:
    """Apply a lightweight quality floor to a parsed cold-start trace.

    The checks run in order and stop at the first failure, which appends a
    short code to ``reason``:
      1. ``leak:<pattern>`` -- a LEAK_PATTERNS entry occurs in the lower-cased
         concatenation of all stages, starting at a word boundary;
      2. ``short_caption`` / ``short_think`` / ``short_prethink`` -- fewer than
         12 / 12 / 6 whitespace-separated words;
      3. ``repeat_caption`` / ``repeat_think`` -- some 4-gram occurs 4+ times.

    Args:
        stages: parsed stages (as returned by parse_stages); must contain
            prethink, caption and think.
        reason: list that receives the failure code; mutated in place.

    Returns:
        True if every check passes, else False.
    """
    full = " ".join(stages.values()).lower()
    for pattern in LEAK_PATTERNS:
        # Anchor the pattern's start at a word boundary. A bare substring test
        # lets "as given" fire inside "was given"/"has given" and "as provided"
        # inside "has provided", which rejects ordinary scene descriptions.
        if re.search(r"(?<!\w)" + re.escape(pattern), full):
            reason.append(f"leak:{pattern}")
            return False

    for key, min_words, code in (
        ("caption", 12, "short_caption"),
        ("think", 12, "short_think"),
        ("prethink", 6, "short_prethink"),
    ):
        if len(stages[key].split()) < min_words:
            reason.append(code)
            return False

    for key in ("caption", "think"):
        if _max_ngram_repeat(stages[key], 4) >= 4:
            reason.append(f"repeat_{key}")
            return False
    return True
|
|
|
|
def norm_letter(s: str) -> str:
    """Reduce a multiple-choice answer to a bare upper-case option letter.

    A leading single letter, optionally wrapped in parentheses and followed
    by a word boundary ("B", "(b)", "C. ..."), becomes that letter upper-cased.
    Anything else is returned stripped and upper-cased unchanged.

    NOTE(review): free text that starts with a one-letter word (e.g. "A man
    walks ...") also normalizes to that letter.
    """
    cleaned = s.strip()
    hit = re.match(r"^\(?([A-Za-z])\)?\b", cleaned)
    if hit is None:
        return cleaned.upper()
    return hit.group(1).upper()
|
|
|
|
| |
| |
|
|
|
|
def load_source_rows(source: str, raw_dir: Path | None = None,
                     vid_dir: Path | None = None) -> list[dict]:
    """Load TemporalBench rows for ``source`` whose video exists on disk.

    Args:
        source: split name, e.g. ``"short_qa"`` or ``"long_qa"``; selects
            ``temporalbench_{source}.json``.
        raw_dir: directory holding the annotation JSON (default: TB_RAW).
        vid_dir: directory holding the videos (default: TB_VID).

    Returns:
        One dict per kept row with keys idx, video_name, video_abs, question,
        GT (stringified and whitespace-stripped) and source. Rows whose video
        file is missing are skipped.
    """
    # Module defaults are resolved lazily so callers (and tests) may point at
    # another data root without touching the globals.
    raw_dir = TB_RAW if raw_dir is None else Path(raw_dir)
    vid_dir = TB_VID if vid_dir is None else Path(vid_dir)

    # A context manager closes the handle deterministically; JSON is UTF-8 by spec.
    with open(raw_dir / f"temporalbench_{source}.json", encoding="utf-8") as fh:
        data = json.load(fh)

    rows = []
    for r in data:
        vn = r["video_name"]
        vabs = vid_dir / vn
        if not vabs.exists():
            continue
        rows.append({
            "idx": r["idx"],
            "video_name": vn,
            "video_abs": str(vabs),
            "question": r["question"],
            "GT": str(r["GT"]).strip(),
            "source": source,
        })
    return rows
|
|
|
|
def make_video_field(video_abs: str, fps: float, max_frames: int, min_frames: int,
                     nframes: int, max_pixels: int) -> dict:
    """Build the qwen_vl_utils ``{"type": "video", ...}`` content item for a local file.

    A positive ``nframes`` selects fixed-count sampling: ``nframes`` is set,
    and ``max_frames``/``min_frames`` are pinned to the same value. Otherwise
    (``nframes`` falsy or <= 0), fps-mode is used with ``fps`` plus the given
    ``max_frames``/``min_frames`` bounds.
    """
    header = {"type": "video", "video": "file://" + video_abs, "max_pixels": max_pixels}
    fixed_count = bool(nframes) and nframes > 0
    if fixed_count:
        sampling = dict(nframes=nframes, max_frames=nframes, min_frames=nframes)
    else:
        sampling = dict(fps=fps, max_frames=max_frames, min_frames=min_frames)
    return {**header, **sampling}
|
|
|
|
def load_video_frames(video_abs: str, fps: float, max_frames: int, min_frames: int,
                      nframes: int, max_pixels: int):
    """Decode ``video_abs`` into a list of PIL RGB frames via qwen_vl_utils.

    Sampling is configured by make_video_field (fps-mode unless ``nframes`` > 0).

    Returns:
        ``(frames, video_kwargs)``: the PIL frames plus the kwargs reported by
        ``process_vision_info(..., return_video_kwargs=True)``, or ``([], {})``
        when no video comes back.
    """
    from qwen_vl_utils import process_vision_info

    message = [{
        "role": "user",
        "content": [
            make_video_field(video_abs, fps, max_frames, min_frames, nframes, max_pixels),
            {"type": "text", "text": ""},
        ],
    }]
    _, videos, video_kwargs = process_vision_info(message, return_video_kwargs=True)
    if not videos:
        return [], {}

    clip = videos[0]
    if not hasattr(clip, "permute"):
        # Array-like result; presumably (T, H, W, 3) uint8 -- TODO confirm.
        arr = np.asarray(clip)
        return [Image.fromarray(arr[t], mode="RGB") for t in range(arr.shape[0])], video_kwargs

    # Tensor result: frames are indexed on dim 0, and each frame is permuted
    # (1, 2, 0) to channels-last for PIL. Non-uint8 values are clamped to
    # [0, 255] and cast (the cast truncates).
    import torch
    if clip.dtype != torch.uint8:
        clip = clip.clamp(0, 255).to(torch.uint8)
    frames = [Image.fromarray(clip[t].permute(1, 2, 0).cpu().numpy(), mode="RGB")
              for t in range(clip.shape[0])]
    return frames, video_kwargs
|
|
|
|
def build_train_user_text(question: str) -> str:
    """Return the user prompt for ``question`` exactly as OMNEX-VL sft.py builds it.

    Only QUESTION_TEMPLATE is applied. TYPE_TEMPLATE is deliberately not
    appended, because OMNEX-VL's prepare_dataset defines it but never uses it.
    TemporalBench questions already carry their own "Answer with the option's
    letter..." instruction, so no suffix is added.
    """
    return QUESTION_TEMPLATE.format_map({"Question": question})
|
|
|
|
| |
|
|
|
|
def parse_args():
    """Parse the command-line options of the cold-start CoT builder."""
    parser = argparse.ArgumentParser()
    opt = parser.add_argument

    # Teacher model / vLLM engine.
    opt("--teacher", default="Qwen/Qwen2.5-VL-72B-Instruct")
    opt("--tp_size", type=int, default=8)
    opt("--gpu_memory_utilization", type=float, default=0.90)
    opt("--max_model_len", type=int, default=16384)

    # Video sampling (forwarded to load_video_frames).
    opt("--fps", type=float, default=2.0, help="frames per second (fps-mode)")
    opt("--max_frames", type=int, default=64, help="cap frames in fps-mode")
    opt("--min_frames", type=int, default=4)
    opt("--nframes", type=int, default=0, help=">0 forces fixed frame count instead of fps-mode")
    opt("--max_pixels", type=int, default=360 * 420)

    # Decoding / batching.
    opt("--batch_size", type=int, default=48)
    opt("--temperature", type=float, default=0.6)
    opt("--top_p", type=float, default=0.9)
    opt("--max_tokens", type=int, default=1024)

    # Run size and data selection.
    opt("--target", type=int, default=10000, help="stop once this many valid samples collected")
    opt("--limit", type=int, default=0, help="cap number of source rows attempted (smoke test)")
    opt("--sources", nargs="+", default=["short_qa", "long_qa"],
        help="process in order; long_qa used as top-up")
    opt("--seed", type=int, default=0)

    # Output files.
    data_dir = ROOT / "data"
    opt("--out_jsonl", default=str(data_dir / "coldstart_cot.raw.jsonl"))
    opt("--out_json", default=str(data_dir / "coldstart_cot_10k.json"))
    return parser.parse_args()
|
|
|
|
def _load_candidate_rows(args) -> tuple[list[dict], int]:
    """Load every requested source and order the rows video by video.

    Videos are shuffled with ``np.random.default_rng(args.seed)``, and all
    questions about one video stay contiguous. Each video is therefore decoded
    once and can be dropped as soon as generation moves past it. ``--limit``
    truncates the ordered list.

    Returns:
        ``(rows, n_videos)``, where ``n_videos`` counts distinct videos before truncation.
    """
    from collections import defaultdict

    rng = np.random.default_rng(args.seed)
    all_rows: list[dict] = []
    for src in args.sources:
        sr = load_source_rows(src)
        all_rows.extend(sr)
        print(f"[gen] source {src}: {len(sr)} rows with on-disk video", flush=True)

    by_vid: dict[str, list[dict]] = defaultdict(list)
    for r in all_rows:
        by_vid[r["video_name"]].append(r)
    vids = list(by_vid.keys())
    rng.shuffle(vids)
    rows = [r for v in vids for r in by_vid[v]]
    if args.limit:
        rows = rows[: args.limit]
    return rows, len(vids)


def _row_key(meta: dict) -> tuple:
    """Identify a source row by ``(source, idx)``; ``idx`` alone may repeat across sources."""
    return (meta.get("source"), meta["idx"])


def _read_resume_state(out_jsonl: Path) -> tuple[set, int]:
    """Scan an existing raw JSONL and return ``(keys to skip, number of valid rows)``.

    Rows whose only records are vLLM engine failures (reason starting with
    ``"vllm:"``) are not skipped, so they are retried on the next run.
    Unparseable or malformed lines are ignored. Valid rows are counted once
    per key.
    """
    done: set = set()
    valid: set = set()
    if not out_jsonl.exists():
        return done, 0
    with out_jsonl.open(encoding="utf-8") as fh:
        for line in fh:
            try:
                rec = json.loads(line)
                key = _row_key(rec["meta"])
            except (ValueError, KeyError, TypeError, AttributeError):
                continue
            if rec.get("valid"):
                valid.add(key)
            elif str(rec.get("reason", "")).startswith("vllm:"):
                continue
            done.add(key)
    return done, len(valid)


def _judge_generation(gen: str, gt: str) -> tuple[bool, dict | None, list[str]]:
    """Apply the rejection filters: four tags present, answer letter == GT, quality floor.

    Returns:
        ``(ok, parsed stages or None, failure reason codes)``.
    """
    stages = parse_stages(gen)
    reason: list[str] = []
    if stages is None:
        reason.append("no_tags")
        return False, None, reason
    if norm_letter(stages["answer"]) != norm_letter(gt):
        reason.append("answer_mismatch")
        return False, stages, reason
    return quality_ok(stages, reason), stages, reason


def main() -> int:
    """Generate teacher CoT traces into the raw JSONL, then compile the training JSON.

    Appends one record per attempted row to ``--out_jsonl``. The run is
    resume-safe: rows already recorded are skipped. Generation stops once
    ``--target`` valid traces exist, and ``compile_json`` then writes
    ``--out_json``.

    Returns:
        0 (the process exit status).
    """
    args = parse_args()

    rows, n_videos = _load_candidate_rows(args)
    print(f"[gen] total candidate rows: {len(rows)} across {n_videos} videos "
          f"(~{len(rows) / max(n_videos, 1):.1f} Q/video) | target valid: {args.target}",
          flush=True)

    # ---- resume: skip rows that already have a (non-engine-failure) record ----
    out_jsonl = Path(args.out_jsonl)
    done_keys, n_valid = _read_resume_state(out_jsonl)
    if out_jsonl.exists():
        print(f"[gen] resume: {len(done_keys)} already processed, {n_valid} valid", flush=True)
        rows = [r for r in rows if _row_key(r) not in done_keys]

    # ---- teacher ----
    print(f"[gen] loading vLLM teacher {args.teacher} (TP={args.tp_size})", flush=True)
    from vllm import LLM, SamplingParams
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(args.teacher, trust_remote_code=True)
    llm = LLM(
        model=args.teacher,
        tensor_parallel_size=args.tp_size,
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 0, "video": 1},
        enforce_eager=False,
        dtype="bfloat16",
        enable_prefix_caching=True,
    )
    sp = SamplingParams(temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens)

    # video_name -> (frames, video_kwargs, decode_error); bounded per batch below.
    frame_cache: dict[str, tuple] = {}
    n_attempt = 0
    n_valid_at_start = n_valid  # the yield log reports this session's rate only

    def _emit(fout, rec: dict) -> None:
        """Append one JSON record to the raw JSONL."""
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

    def _generate(reqs: list[dict], metas: list[dict], fout) -> list[tuple]:
        """Run the teacher and return ``(output, meta)`` pairs.

        If the batched call raises, each request is retried on its own, so a
        single failing request does not discard the whole batch. Requests that
        still fail are recorded with a ``vllm:`` reason.
        """
        if not reqs:
            return []
        try:
            outs = llm.generate(reqs, sampling_params=sp, use_tqdm=False)
            return list(zip(outs, metas))
        except Exception as e:  # engine errors are not a stable exception type
            print(f"[gen] batch generate failed ({e}); retrying {len(reqs)} requests one by one",
                  flush=True)
        pairs = []
        for req, b in zip(reqs, metas):
            try:
                pairs.append((llm.generate([req], sampling_params=sp, use_tqdm=False)[0], b))
            except Exception as e:
                _emit(fout, {"valid": False, "reason": f"vllm:{e}", "meta": b})
        return pairs

    def flush_batch(batch: list[dict], fout) -> None:
        nonlocal n_valid, n_attempt

        # Rows arrive grouped by video, so decoded frames are only needed while
        # that video's questions are in the current batch (a video straddling
        # a batch boundary appears in both batches). Evict everything else;
        # otherwise every decoded video stays in RAM for the whole run.
        live = {b["video_name"] for b in batch}
        for stale in [vn for vn in frame_cache if vn not in live]:
            del frame_cache[stale]

        reqs, metas = [], []
        for b in batch:
            vn = b["video_name"]
            if vn not in frame_cache:
                try:
                    frames, vkw = load_video_frames(
                        b["video_abs"], args.fps, args.max_frames, args.min_frames,
                        args.nframes, args.max_pixels)
                    frame_cache[vn] = (frames, vkw, "")
                except Exception as e:  # decoder failures vary by backend; record, don't crash
                    frame_cache[vn] = ([], {}, f"{type(e).__name__}: {e}")
            frames, vkw, err = frame_cache[vn]
            if not frames:
                _emit(fout, {"valid": False, "reason": f"decode:{err}" if err else "decode",
                             "meta": b})
                continue
            # GT is shown to the teacher only; the stored problem stays the bare question.
            gen_user = build_train_user_text(b["question"]) + GEN_GT_HINT.format(GT=b["GT"])
            msgs = [
                {"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
                {"role": "user", "content": [{"type": "video"}, {"type": "text", "text": gen_user}]},
            ]
            text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
            reqs.append({"prompt": text, "multi_modal_data": {"video": frames},
                         "mm_processor_kwargs": vkw})
            metas.append(b)

        for out, b in _generate(reqs, metas, fout):
            n_attempt += 1
            gen = out.outputs[0].text
            ok, stages, reason = _judge_generation(gen, b["GT"])
            rec = {"valid": bool(ok), "meta": b, "gen": gen}
            if ok:
                rec["stages"] = stages
                n_valid += 1
            else:
                rec["reason"] = ";".join(reason) or "unknown"
            _emit(fout, rec)
        fout.flush()

    out_jsonl.parent.mkdir(parents=True, exist_ok=True)
    batch_size = max(1, args.batch_size)  # range() rejects a zero step
    t0 = time.time()
    # `with` guarantees the raw log is closed even if generation raises.
    with out_jsonl.open("a", encoding="utf-8") as fout:
        for i in range(0, len(rows), batch_size):
            if n_valid >= args.target:
                print(f"[gen] reached target {args.target}, stopping.", flush=True)
                break
            flush_batch(rows[i: i + batch_size], fout)
            rate = (n_valid - n_valid_at_start) / max(1, n_attempt)
            el = time.time() - t0
            print(f"[gen] attempted={n_attempt} valid={n_valid} ({rate:.1%}) "
                  f"elapsed={el / 60:.1f}m {n_attempt / max(el, 1):.2f} it/s", flush=True)

    compile_json(out_jsonl, Path(args.out_json), args.target)
    return 0
|
|
|
|
def compile_json(raw_jsonl: Path, out_json: Path, target: int):
    """Turn valid raw records into the OMNEX sft.py training schema.

    Reads ``raw_jsonl`` line by line. It skips unparseable lines, invalid
    records, and repeated ``(source, idx)`` rows, keeping only the first valid
    record per row. At most ``target`` samples are written to ``out_json``
    (``target <= 0`` yields an empty list) as a JSON list of
    ``{problem, data_type, path, process_and_answer, meta}``. Parent
    directories of ``out_json`` are created if missing.
    """
    samples: list[dict] = []
    seen: set = set()
    with Path(raw_jsonl).open(encoding="utf-8") as fh:
        for line in fh:
            # Check before appending so that target=0 really produces no samples.
            if len(samples) >= target:
                break
            try:
                r = json.loads(line)
            except ValueError:  # includes json.JSONDecodeError (e.g. a truncated last line)
                continue
            if not isinstance(r, dict) or not r.get("valid"):
                continue
            b = r["meta"]
            key = (b.get("source"), b["idx"])
            if key in seen:  # a re-run may log the same row twice
                continue
            seen.add(key)
            s = r["stages"]
            process_and_answer = (
                f"<prethink>{s['prethink']}</prethink>\n"
                f"<caption>{s['caption']}</caption>\n"
                f"<think>{s['think']}</think>\n"
                f"<answer>{s['answer']}</answer>"
            )
            samples.append({
                "problem": b["question"],
                "data_type": "video",
                "path": b["video_abs"],
                "process_and_answer": process_and_answer,
                "meta": {"idx": b["idx"], "source": b["source"], "GT": b["GT"]},
            })

    out_path = Path(out_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as fh:
        json.dump(samples, fh, ensure_ascii=False, indent=1)
    print(f"[gen] wrote {len(samples)} samples -> {out_json}", flush=True)
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|
|