"""Phase 1 motivation script: compute per-token grounding score g_t. g_t(token) = log p_T(token | video, prefix) - log p_T(token | video_perturbed, prefix) We run the 72B teacher twice on (prompt + SFT response) -- once on the real video and once on a frame-count-preserving perturbed video -- using vLLM's prompt_logprobs to recover the per-token log-prob under both conditioning. g_t is then a per-token tensor. This is OFFLINE: it does not depend on a trained student. It is the motivation figure for CD-OPD (does the teacher's response really partition into "grounded" vs "free-rider" tokens). Run it AFTER vanilla OPD has released the GPUs. Outputs a parquet with one row per (sample, token): sample_idx | data_source | token_pos | token_id | token_text | logp_v | logp_v_perturbed | g_t | response_text """ from __future__ import annotations import argparse import json import os import sys import time from pathlib import Path from typing import Any import numpy as np import pandas as pd from PIL import Image ROOT = Path("/mnt/local-fast/opd_zt") def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--input", default=str(ROOT / "data" / "sft_5k.parquet")) p.add_argument("--output", default=str(ROOT / "data" / "grounding_g_t.parquet")) p.add_argument("--teacher", default="Qwen/Qwen2.5-VL-72B-Instruct") p.add_argument("--tp_size", type=int, default=4) p.add_argument("--num_samples", type=int, default=500) p.add_argument("--seed", type=int, default=0) p.add_argument( "--perturbation", choices=["black_frames", "shuffle", "mean_frame"], default="black_frames", ) p.add_argument("--max_model_len", type=int, default=9216) p.add_argument("--gpu_memory_utilization", type=float, default=0.80) p.add_argument("--checkpoint_every", type=int, default=20) return p.parse_args() def perturb_video_frames(frames: list[Image.Image], perturbation: str) -> list[Image.Image]: if perturbation == "black_frames": black = Image.new(frames[0].mode, frames[0].size, color=0) return [black.copy() for _ in frames] if perturbation == "shuffle": idx = list(range(len(frames))) np.random.shuffle(idx) return [frames[i] for i in idx] if perturbation == "mean_frame": arrs = np.stack([np.asarray(f, dtype=np.float32) for f in frames], axis=0) mean = arrs.mean(axis=0).astype(np.uint8) mean_img = Image.fromarray(mean, mode=frames[0].mode) return [mean_img.copy() for _ in frames] raise ValueError(perturbation) def load_video_frames(video_field: dict) -> list[Image.Image]: """Decode a video using the same settings as build_sft_dataset.make_video_dict.""" # video_field example: {"video": "file:///.../X.mp4", "max_frames": 32, "min_frames": 32, # "nframes": 32, "max_pixels": 151200} from qwen_vl_utils import process_vision_info msg = [{ "role": "user", "content": [ { "type": "video", "video": video_field["video"], "max_frames": int(video_field.get("max_frames", 32)), "min_frames": int(video_field.get("min_frames", 32)), "nframes": int(video_field.get("nframes", 32)), "max_pixels": int(video_field.get("max_pixels", 360 * 420)), }, {"type": "text", "text": ""}, ], }] _, video_inputs, video_kwargs = process_vision_info(msg, return_video_kwargs=True) if video_inputs is None or len(video_inputs) == 0: return [], {} # video_inputs is a list of (T, C, H, W) tensors -- one per video. We have one video. vid = video_inputs[0] # Convert to PIL frames so we can perturb / pass list-of-PIL to vLLM. if hasattr(vid, "permute"): # torch tensor # (T, C, H, W) [0..255] uint8 -> PIL import torch # local import if vid.dtype != torch.uint8: vid = vid.clamp(0, 255).to(torch.uint8) frames = [ Image.fromarray(vid[i].permute(1, 2, 0).cpu().numpy(), mode="RGB") for i in range(vid.shape[0]) ] else: # numpy # (T, H, W, C) uint8 arr = np.asarray(vid) frames = [Image.fromarray(arr[i], mode="RGB") for i in range(arr.shape[0])] return frames, video_kwargs def build_chat_text(processor, user_text: str, assistant_text: str) -> str: """Render the prompt+response as a single chat string (no generation prompt).""" msgs = [ {"role": "user", "content": [ {"type": "video"}, # placeholder; vLLM consumes multi_modal_data separately {"type": "text", "text": user_text}, ]}, {"role": "assistant", "content": assistant_text}, ] return processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False) def extract_response_logprobs( prompt_logprobs: list[dict | None], prompt_ids: list[int], response_start: int, ) -> tuple[list[float], list[int]]: """Pull out per-response-token logprobs from a vLLM PromptLogprobs list. vLLM's prompt_logprobs[i] is the logprob distribution conditioned on tokens [0..i-1] over candidates at position i. The actually-realized token at position i is prompt_ids[i]. We want logprob of that realized token. """ out_logp: list[float] = [] out_ids: list[int] = [] for pos in range(response_start, len(prompt_ids)): slot = prompt_logprobs[pos] if pos < len(prompt_logprobs) else None if slot is None: continue tok = prompt_ids[pos] entry = slot.get(tok) if isinstance(slot, dict) else None if entry is None: continue # vLLM Logprob object exposes .logprob, but in some versions it's a dict. lp = getattr(entry, "logprob", None) if lp is None and isinstance(entry, dict): lp = entry.get("logprob") if lp is None: continue out_logp.append(float(lp)) out_ids.append(int(tok)) return out_logp, out_ids def main() -> int: args = parse_args() rng = np.random.default_rng(args.seed) # --- load data --- df = pd.read_parquet(args.input) n_total = len(df) if args.num_samples >= n_total: sample_idx = np.arange(n_total) else: sample_idx = rng.choice(n_total, size=args.num_samples, replace=False) sample_idx = np.sort(sample_idx) print(f"[grounding] sampling {len(sample_idx)} / {n_total} rows", flush=True) # --- load model & processor --- print(f"[grounding] loading vLLM teacher: {args.teacher}", flush=True) from vllm import LLM, SamplingParams from transformers import AutoProcessor processor = AutoProcessor.from_pretrained(args.teacher, trust_remote_code=True) llm = LLM( model=args.teacher, tensor_parallel_size=args.tp_size, max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, trust_remote_code=True, limit_mm_per_prompt={"image": 0, "video": 1}, enforce_eager=False, dtype="bfloat16", ) sp = SamplingParams( max_tokens=1, temperature=1.0, prompt_logprobs=0, # return just realized-token logprobs ) # --- iterate --- out_path = Path(args.output) out_path.parent.mkdir(parents=True, exist_ok=True) tmp_path = out_path.with_suffix(out_path.suffix + ".part") rows: list[dict] = [] t0 = time.time() for i, idx in enumerate(sample_idx): row = df.iloc[int(idx)] # prompt is ndarray of 1 dict prompt_msgs = list(row["prompt"]) # The user prompt content already contains "