""" Build video cold-start CoT data in the OMNEX-VL ("Ground What You See") format. Pipeline (mirrors LLaVA-CoT construction, but for video): - Source: TemporalBench short_qa (MCQ, clean GT letter). Optional top-up: long_qa. - Teacher: Qwen2.5-VL-72B-Instruct via vLLM, sees the real video. - The GT answer is PROVIDED to the teacher (LLaVA-CoT style) so it writes a plausible 4-stage trace that lands on GT. GT is used only at generation time; it is NOT part of the stored training prompt. - Output 4 stages with tags: . - Rejection sampling: keep only rows where all four tags parse AND the extracted matches GT (case-insensitive letter match). Output (OMNEX sft.py schema): data/coldstart_cot_10k.json -> list of {problem, data_type, path, process_and_answer, meta} also writes incrementally to data/coldstart_cot.raw.jsonl (resume-safe). Usage: python scripts/build_coldstart_cot.py --limit 8 # smoke test python scripts/build_coldstart_cot.py --target 10000 # full run """ from __future__ import annotations import argparse import json import os import re import sys import time from pathlib import Path import numpy as np from PIL import Image ROOT = Path("/mnt/local-fast/opd_zt") TB_RAW = ROOT / "data" / "raw" / "tb" TB_VID = ROOT / "data" / "videos" / "tb" # --------------------------------------------------------------------------- # OMNEX-VL prompt templates (verbatim from train/.../cold-start/train/sft.py) SYSTEM_MESSAGE = "You are a helpful assistant" QUESTION_TEMPLATE = ( "{Question}\n" "Please carefully analyze the pictures (or videos) and problems according to the following requirements" "In tags, carefully analyze the problem and briefly explain the steps to explain the problem and the key thinking direction of reasoning the problem" "In tags, Please describe the image carefully, paying special attention to the details related to the problem and the reasoning direction of solving the problem" "In tags, outline a step-by-step thought process you would use to solve the problem based on the image" "In tags, give the final answer in a direct format, and it must match the correct answer exactly." "Please sort out the output in the format of '...\n...\n...\n...' according to the above requirements" ) TYPE_TEMPLATE = { "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the tags.", "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the tags.", } # Extra instruction shown ONLY to the teacher at generation time (stripped from # the stored training prompt). Provides the gold answer, LLaVA-CoT style. GEN_GT_HINT = ( "\n[Supervision only — do NOT mention that the answer was given] " "The verified correct answer is: {GT}. " "Write the // so the reasoning genuinely supports this, " "and make exactly {GT}." ) TAG_RE = { t: re.compile(rf"<{t}>(.*?)", re.DOTALL) for t in ("prethink", "caption", "think", "answer") } def parse_stages(text: str) -> dict | None: """Return {prethink,caption,think,answer} if all four tags present, else None.""" out = {} for t, rx in TAG_RE.items(): m = rx.search(text) if not m: return None out[t] = m.group(1).strip() if not all(out.values()): return None return out # phrases that reveal the GT was handed to the teacher (LLaVA-CoT leakage) LEAK_PATTERNS = [ "correct answer is provided", "answer is given", "answer was given", "as given", "as provided", "provided answer", "given answer", "supervision", "you told me", "since the answer is", "we are told", "the correct answer, which is", "verified correct answer", "do not mention", "[supervision", "told that the answer", ] def _max_ngram_repeat(text: str, n: int = 4) -> int: toks = text.split() if len(toks) < n: return 0 from collections import Counter grams = Counter(tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)) return max(grams.values()) if grams else 0 def quality_ok(stages: dict, reason: list) -> bool: """Lightweight quality floor for cold-start traces. Mutates `reason` with the failure.""" full = " ".join(stages.values()).lower() for p in LEAK_PATTERNS: if p in full: reason.append(f"leak:{p}") return False # min substance per stage (word counts) if len(stages["caption"].split()) < 12: reason.append("short_caption"); return False if len(stages["think"].split()) < 12: reason.append("short_think"); return False if len(stages["prethink"].split()) < 6: reason.append("short_prethink"); return False # repetition / degeneration for k in ("caption", "think"): if _max_ngram_repeat(stages[k], 4) >= 4: reason.append(f"repeat_{k}"); return False return True def norm_letter(s: str) -> str: """Normalize an MCQ answer to a bare uppercase letter when possible.""" s = s.strip() m = re.match(r"^\(?([A-Za-z])\)?\b", s) return m.group(1).upper() if m else s.upper() # --------------------------------------------------------------------------- # data loading def load_source_rows(source: str) -> list[dict]: """Load TB rows for a source ('short_qa' | 'long_qa'), only those with a video on disk.""" path = TB_RAW / f"temporalbench_{source}.json" data = json.load(open(path)) rows = [] for r in data: vn = r["video_name"] vabs = TB_VID / vn if not vabs.exists(): continue rows.append({ "idx": r["idx"], "video_name": vn, "video_abs": str(vabs), "question": r["question"], "GT": str(r["GT"]).strip(), "source": source, }) return rows def make_video_field(video_abs: str, fps: float, max_frames: int, min_frames: int, nframes: int, max_pixels: int) -> dict: """Build the qwen_vl_utils video dict. fps-mode when nframes<=0 (fps + frame cap).""" field = {"type": "video", "video": "file://" + video_abs, "max_pixels": max_pixels} if nframes and nframes > 0: field.update({"nframes": nframes, "max_frames": nframes, "min_frames": nframes}) else: field.update({"fps": fps, "max_frames": max_frames, "min_frames": min_frames}) return field def load_video_frames(video_abs: str, fps: float, max_frames: int, min_frames: int, nframes: int, max_pixels: int): """Decode a video to a list of PIL frames using qwen_vl_utils (same as compute_grounding).""" from qwen_vl_utils import process_vision_info vfield = make_video_field(video_abs, fps, max_frames, min_frames, nframes, max_pixels) msg = [{ "role": "user", "content": [vfield, {"type": "text", "text": ""}], }] _, video_inputs, video_kwargs = process_vision_info(msg, return_video_kwargs=True) if not video_inputs: return [], {} vid = video_inputs[0] if hasattr(vid, "permute"): import torch if vid.dtype != torch.uint8: vid = vid.clamp(0, 255).to(torch.uint8) frames = [Image.fromarray(vid[i].permute(1, 2, 0).cpu().numpy(), mode="RGB") for i in range(vid.shape[0])] else: arr = np.asarray(vid) frames = [Image.fromarray(arr[i], mode="RGB") for i in range(arr.shape[0])] return frames, video_kwargs def build_train_user_text(question: str) -> str: """User prompt — EXACTLY OMNEX-VL sft.py: QUESTION_TEMPLATE only (TYPE_TEMPLATE is defined but never appended in their prepare_dataset). TB questions already carry the 'Answer with the option's letter...' instruction, so no suffix is needed.""" return QUESTION_TEMPLATE.format(Question=question) # --------------------------------------------------------------------------- def parse_args(): ap = argparse.ArgumentParser() ap.add_argument("--teacher", default="Qwen/Qwen2.5-VL-72B-Instruct") ap.add_argument("--tp_size", type=int, default=8) ap.add_argument("--gpu_memory_utilization", type=float, default=0.90) ap.add_argument("--max_model_len", type=int, default=16384) ap.add_argument("--fps", type=float, default=2.0, help="frames per second (fps-mode)") ap.add_argument("--max_frames", type=int, default=64, help="cap frames in fps-mode") ap.add_argument("--min_frames", type=int, default=4) ap.add_argument("--nframes", type=int, default=0, help=">0 forces fixed frame count instead of fps-mode") ap.add_argument("--max_pixels", type=int, default=360 * 420) ap.add_argument("--batch_size", type=int, default=48) ap.add_argument("--temperature", type=float, default=0.6) ap.add_argument("--top_p", type=float, default=0.9) ap.add_argument("--max_tokens", type=int, default=1024) ap.add_argument("--target", type=int, default=10000, help="stop once this many valid samples collected") ap.add_argument("--limit", type=int, default=0, help="cap number of source rows attempted (smoke test)") ap.add_argument("--sources", nargs="+", default=["short_qa", "long_qa"], help="process in order; long_qa used as top-up") ap.add_argument("--seed", type=int, default=0) ap.add_argument("--out_jsonl", default=str(ROOT / "data" / "coldstart_cot.raw.jsonl")) ap.add_argument("--out_json", default=str(ROOT / "data" / "coldstart_cot_10k.json")) return ap.parse_args() def main() -> int: args = parse_args() rng = np.random.default_rng(args.seed) # --- gather source rows --- all_rows: list[dict] = [] for src in args.sources: sr = load_source_rows(src) all_rows.extend(sr) print(f"[gen] source {src}: {len(sr)} rows with on-disk video", flush=True) # Group questions by video and emit same-video questions consecutively so the # shared [system +