| """Phase 1 motivation script: compute per-token grounding score g_t. |
| |
| g_t(token) = log p_T(token | video, prefix) - log p_T(token | video_perturbed, prefix) |
| |
| We run the 72B teacher twice on (prompt + SFT response) -- once on the real video and |
| once on a frame-count-preserving perturbed video -- using vLLM's prompt_logprobs to |
| recover each token's log-prob under both conditionings. g_t is then one score per token. |
| |
| This is OFFLINE: it does not depend on a trained student. It is the motivation figure |
| for CD-OPD (does the teacher's response really partition into "grounded" vs |
| "free-rider" tokens). Run it AFTER vanilla OPD has released the GPUs. |
| |
| Outputs a parquet with one row per (sample, token): |
| sample_idx | data_source | token_pos | token_id | token_text | |
| logp_v | logp_v_perturbed | g_t | response_len | perturbation |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import pandas as pd |
| from PIL import Image |
|
|
# Base directory under which the default --input and --output parquet paths
# are resolved (see parse_args).
ROOT = Path("/mnt/local-fast/opd_zt")
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the grounding-score script.

    Returns:
        The populated namespace. Every option has a default, so the script can
        be launched without arguments.
    """
    data_dir = ROOT / "data"
    # One (flag, add_argument keyword arguments) pair per option, in the order
    # they should appear in --help.
    option_table: tuple[tuple[str, dict], ...] = (
        ("--input", {"default": str(data_dir / "sft_5k.parquet")}),
        ("--output", {"default": str(data_dir / "grounding_g_t.parquet")}),
        ("--teacher", {"default": "Qwen/Qwen2.5-VL-72B-Instruct"}),
        ("--tp_size", {"type": int, "default": 4}),
        ("--num_samples", {"type": int, "default": 500}),
        ("--seed", {"type": int, "default": 0}),
        (
            "--perturbation",
            {
                "choices": ["black_frames", "shuffle", "mean_frame"],
                "default": "black_frames",
            },
        ),
        ("--max_model_len", {"type": int, "default": 9216}),
        ("--gpu_memory_utilization", {"type": float, "default": 0.80}),
        ("--checkpoint_every", {"type": int, "default": 20}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
def perturb_video_frames(
    frames: list[Image.Image],
    perturbation: str,
    rng: np.random.Generator | None = None,
) -> list[Image.Image]:
    """Return a perturbed version of ``frames`` with the same number of frames.

    Args:
        frames: Decoded video frames. The input list itself is never mutated.
        perturbation: One of
            ``"black_frames"`` -- every frame becomes an all-zero image with the
            mode and size of the first frame;
            ``"shuffle"`` -- the same frame objects in a random order;
            ``"mean_frame"`` -- every frame becomes the per-pixel mean over all
            frames (cast back to uint8).
        rng: Optional generator used for ``"shuffle"`` so the permutation can be
            made reproducible. When omitted, the global ``np.random`` state is
            used, exactly as before this parameter existed.

    Returns:
        A new list with ``len(frames)`` frames; an empty list for empty input.

    Raises:
        ValueError: If ``perturbation`` is not one of the supported names.
    """
    # Validate first so a bad mode is reported even when there are no frames.
    if perturbation not in ("black_frames", "shuffle", "mean_frame"):
        raise ValueError(f"unknown perturbation: {perturbation!r}")
    # Nothing to perturb; also avoids frames[0] / np.stack failing on [].
    if not frames:
        return []

    first = frames[0]
    if perturbation == "black_frames":
        black = Image.new(first.mode, first.size, color=0)
        return [black.copy() for _ in frames]

    if perturbation == "shuffle":
        order = list(range(len(frames)))
        # Both np.random (module) and Generator expose an in-place shuffle().
        shuffler = np.random if rng is None else rng
        shuffler.shuffle(order)
        return [frames[i] for i in order]

    # perturbation == "mean_frame"
    stacked = np.stack([np.asarray(f, dtype=np.float32) for f in frames], axis=0)
    mean = stacked.mean(axis=0).astype(np.uint8)
    mean_img = Image.fromarray(mean, mode=first.mode)
    return [mean_img.copy() for _ in frames]
|
|
|
|
def load_video_frames(video_field: dict) -> tuple[list[Image.Image], dict[str, Any]]:
    """Decode a video using the same settings as build_sft_dataset.make_video_dict.

    Args:
        video_field: Mapping with a ``"video"`` entry and, optionally,
            ``max_frames`` / ``min_frames`` / ``nframes`` (each defaulting to
            32) and ``max_pixels`` (defaulting to ``360 * 420``).

    Returns:
        ``(frames, video_kwargs)``: the decoded frames as RGB PIL images and the
        keyword arguments returned by ``process_vision_info``. ``([], {})`` is
        returned when no video input was produced.
    """
    # Local import: qwen_vl_utils is only needed once a video is decoded.
    from qwen_vl_utils import process_vision_info

    video_item = {
        "type": "video",
        "video": video_field["video"],
        "max_frames": int(video_field.get("max_frames", 32)),
        "min_frames": int(video_field.get("min_frames", 32)),
        "nframes": int(video_field.get("nframes", 32)),
        "max_pixels": int(video_field.get("max_pixels", 360 * 420)),
    }
    msg = [{
        "role": "user",
        "content": [video_item, {"type": "text", "text": ""}],
    }]
    _, video_inputs, video_kwargs = process_vision_info(msg, return_video_kwargs=True)
    if video_inputs is None or len(video_inputs) == 0:
        return [], {}

    vid = video_inputs[0]

    if hasattr(vid, "permute"):
        # Tensor path -- presumably a (T, C, H, W) torch tensor; each frame is
        # moved to channel-last before being handed to PIL. TODO confirm layout.
        import torch

        if vid.dtype != torch.uint8:
            vid = vid.clamp(0, 255).to(torch.uint8)
        frames = [
            Image.fromarray(vid[i].permute(1, 2, 0).cpu().numpy(), mode="RGB")
            for i in range(vid.shape[0])
        ]
    else:
        # Anything else is converted through NumPy and indexed along axis 0;
        # assumes each item is already an RGB-compatible array -- TODO confirm.
        arr = np.asarray(vid)
        frames = [Image.fromarray(arr[i], mode="RGB") for i in range(arr.shape[0])]
    return frames, video_kwargs
|
|
|
|
def build_chat_text(processor, user_text: str, assistant_text: str) -> str:
    """Render one user turn (video + text) and one assistant turn as chat text.

    Args:
        processor: Object exposing ``apply_chat_template``.
        user_text: Text part of the user message; the video is represented by a
            bare ``{"type": "video"}`` content item placed before it.
        assistant_text: The assistant reply to append.

    Returns:
        Whatever ``processor.apply_chat_template`` returns when asked for
        untokenized output without a trailing generation prompt.
    """
    user_turn = {
        "role": "user",
        "content": [{"type": "video"}, {"type": "text", "text": user_text}],
    }
    assistant_turn = {"role": "assistant", "content": assistant_text}
    conversation = [user_turn, assistant_turn]
    rendered = processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
    return rendered
|
|
|
|
def extract_response_logprobs(
    prompt_logprobs: list[dict | None],
    prompt_ids: list[int],
    response_start: int,
) -> tuple[list[float], list[int]]:
    """Collect the log-prob of each realized token from ``response_start`` on.

    ``prompt_logprobs[i]`` is expected to map candidate token ids at position
    ``i`` to an entry carrying a ``logprob`` (either as an attribute or as a
    dict key). The token that actually occurs at position ``i`` is
    ``prompt_ids[i]``; only that token's log-prob is kept.

    Positions with no usable value are dropped silently: a position beyond the
    end of ``prompt_logprobs``, a slot that is not a dict, a slot lacking the
    realized token, or an entry without a logprob.

    Returns:
        ``(logprobs, token_ids)``, two parallel lists covering the kept positions.
    """

    def realized_logprob(position: int):
        # Log-prob of the realized token at `position`, or None if unavailable.
        if position >= len(prompt_logprobs):
            return None
        candidates = prompt_logprobs[position]
        if not isinstance(candidates, dict):
            return None
        record = candidates.get(prompt_ids[position])
        if record is None:
            return None
        value = getattr(record, "logprob", None)
        if value is None and isinstance(record, dict):
            value = record.get("logprob")
        return value

    kept: list[tuple[float, int]] = []
    for position in range(response_start, len(prompt_ids)):
        value = realized_logprob(position)
        if value is not None:
            kept.append((float(value), int(prompt_ids[position])))

    return [lp for lp, _ in kept], [tok for _, tok in kept]
|
|
|
|
def _find_response_start(
    tokenizer,
    text: str,
    prefix_text: str,
    prompt_ids: list[int],
) -> int | None:
    """Return the index in ``prompt_ids`` of the first token after ``prefix_text``.

    The index is counted back from the END of the prompt. Tokenizing the prefix
    on its own would undercount if the engine's prompt ids contain an expanded
    video placeholder; the text after the prefix contains no placeholder, so its
    token count is not affected by that.

    Returns ``None`` when the tail of ``prompt_ids`` does not match the
    tokenized suffix, i.e. when the alignment cannot be trusted.
    """
    suffix_ids = list(
        tokenizer(text[len(prefix_text):], add_special_tokens=False)["input_ids"]
    )
    start = len(prompt_ids) - len(suffix_ids)
    if start <= 0 or list(prompt_ids[start:]) != suffix_ids:
        return None
    return start


def main() -> int:
    """Score sampled SFT rows with the teacher and write per-token g_t rows.

    For each sampled row the (prompt + response) text is scored twice with
    vLLM ``prompt_logprobs`` -- once with the decoded video and once with the
    perturbed video -- and one output row is emitted per scored token with
    ``g_t = logp_v - logp_v_perturbed``. Samples that fail to decode, score or
    align are logged and skipped.

    Returns:
        0 when the output parquet was written, 2 when no row was produced.
    """
    args = parse_args()
    rng = np.random.default_rng(args.seed)
    # perturb_video_frames' "shuffle" mode draws from the global NumPy RNG when
    # no generator is passed, so seed that state too for reproducible runs.
    np.random.seed(args.seed)

    df = pd.read_parquet(args.input)
    n_total = len(df)
    if args.num_samples >= n_total:
        sample_idx = np.arange(n_total)
    else:
        sample_idx = rng.choice(n_total, size=args.num_samples, replace=False)
        sample_idx = np.sort(sample_idx)
    print(f"[grounding] sampling {len(sample_idx)} / {n_total} rows", flush=True)

    print(f"[grounding] loading vLLM teacher: {args.teacher}", flush=True)
    # Heavy imports are deferred until after argument parsing and data loading.
    from vllm import LLM, SamplingParams
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(args.teacher, trust_remote_code=True)
    tokenizer = processor.tokenizer
    llm = LLM(
        model=args.teacher,
        tensor_parallel_size=args.tp_size,
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 0, "video": 1},
        enforce_eager=False,
        dtype="bfloat16",
    )

    # Only the prompt log-probs are of interest; one generated token is the
    # minimum, and prompt_logprobs=0 requests just the realized token's value.
    sp = SamplingParams(
        max_tokens=1,
        temperature=1.0,
        prompt_logprobs=0,
    )

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_suffix(out_path.suffix + ".part")
    rows: list[dict] = []
    t0 = time.time()

    for i, idx in enumerate(sample_idx):
        row = df.iloc[int(idx)]

        prompt_msgs = list(row["prompt"])
        # The video is re-inserted by the chat template, so drop the textual tag.
        user_text = (
            prompt_msgs[0]["content"].replace("<video>\n", "").replace("<video>", "").strip()
        )
        response_text = str(row["response"])
        video_field = dict(row["videos"][0])
        data_source = str(row["data_source"])

        try:
            frames, video_kwargs = load_video_frames(video_field)
        except Exception as e:
            print(f"[grounding] skip idx={idx} (video decode failed: {e})", flush=True)
            continue
        if not frames:
            print(f"[grounding] skip idx={idx} (no frames)", flush=True)
            continue
        perturbed_frames = perturb_video_frames(frames, args.perturbation)

        text = build_chat_text(processor, user_text, response_text)

        marker = "<|im_start|>assistant\n"
        if marker not in text:
            print(f"[grounding] skip idx={idx} (no assistant marker)", flush=True)
            continue
        # Everything up to and including the last assistant header is "prefix".
        prefix_text = text[: text.rindex(marker) + len(marker)]

        try:
            out_v = llm.generate(
                [{
                    "prompt": text,
                    "multi_modal_data": {"video": frames},
                    "mm_processor_kwargs": video_kwargs,
                }],
                sampling_params=sp,
                use_tqdm=False,
            )
            out_vp = llm.generate(
                [{
                    "prompt": text,
                    "multi_modal_data": {"video": perturbed_frames},
                    "mm_processor_kwargs": video_kwargs,
                }],
                sampling_params=sp,
                use_tqdm=False,
            )
        except Exception as e:
            print(f"[grounding] skip idx={idx} (vllm failed: {e})", flush=True)
            continue

        prompt_ids = list(out_v[0].prompt_token_ids or [])
        # g_t compares the two runs position by position, which is only
        # meaningful if both runs saw the same token sequence.
        if list(out_vp[0].prompt_token_ids or []) != prompt_ids:
            print(
                f"[grounding] skip idx={idx} (prompt tokens differ between real and perturbed video)",
                flush=True,
            )
            continue
        plp_v = out_v[0].prompt_logprobs or []
        plp_vp = out_vp[0].prompt_logprobs or []

        # Align from the end of the prompt: a text-only tokenization of the
        # prefix would undercount if prompt_ids holds an expanded video
        # placeholder, and would then label video/prompt tokens as response.
        response_start = _find_response_start(tokenizer, text, prefix_text, prompt_ids)
        if response_start is None:
            print(
                f"[grounding] skip idx={idx} (could not align response tokens in prompt)",
                flush=True,
            )
            continue

        logp_v, tok_ids_v = extract_response_logprobs(plp_v, prompt_ids, response_start)
        logp_vp, tok_ids_vp = extract_response_logprobs(plp_vp, prompt_ids, response_start)
        if len(logp_v) != len(logp_vp) or tok_ids_v != tok_ids_vp:
            print(
                f"[grounding] skip idx={idx} (logprob length mismatch v={len(logp_v)} vp={len(logp_vp)})",
                flush=True,
            )
            continue

        for pos, (tid, lv, lvp) in enumerate(zip(tok_ids_v, logp_v, logp_vp, strict=True)):
            tok_text = tokenizer.decode([tid])
            rows.append({
                "sample_idx": int(idx),
                "data_source": data_source,
                "token_pos": pos,
                "token_id": tid,
                "token_text": tok_text,
                "logp_v": lv,
                "logp_v_perturbed": lvp,
                "g_t": lv - lvp,
                "response_len": len(tok_ids_v),
                "perturbation": args.perturbation,
            })

        if (i + 1) % 10 == 0:
            dt = time.time() - t0
            eta = dt / (i + 1) * (len(sample_idx) - i - 1)
            print(
                f"[grounding] {i+1}/{len(sample_idx)} samples rows={len(rows)} "
                f"elapsed={dt/60:.1f}min ETA={eta/60:.1f}min",
                flush=True,
            )
        # A non-positive --checkpoint_every disables checkpoints instead of
        # raising ZeroDivisionError on the modulo.
        if args.checkpoint_every > 0 and (i + 1) % args.checkpoint_every == 0 and rows:
            pd.DataFrame(rows).to_parquet(tmp_path, index=False)

    if not rows:
        print("[grounding] no rows produced; exiting non-zero", flush=True)
        return 2
    pd.DataFrame(rows).to_parquet(out_path, index=False)
    # The partial checkpoint is redundant once the final file exists.
    if tmp_path.exists():
        tmp_path.unlink()
    print(f"[grounding] wrote {len(rows)} rows -> {out_path}", flush=True)
    return 0
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|