# opd_zt/scripts/compute_grounding.py
# (Hugging Face upload by sdzt via upload-large-folder tool; commit bf46e5d; 12.1 kB)
"""Phase 1 motivation script: compute per-token grounding score g_t.
g_t(token) = log p_T(token | video, prefix) - log p_T(token | video_perturbed, prefix)
We run the 72B teacher twice on (prompt + SFT response) -- once on the real video and
once on a frame-count-preserving perturbed video -- using vLLM's prompt_logprobs to
recover the per-token log-prob under both conditions. g_t is then a per-token vector over the response.
This is OFFLINE: it does not depend on a trained student. It is the motivation figure
for CD-OPD (does the teacher's response really partition into "grounded" vs
"free-rider" tokens). Run it AFTER vanilla OPD has released the GPUs.
Outputs a parquet with one row per (sample, token):
sample_idx | data_source | token_pos | token_id | token_text |
logp_v | logp_v_perturbed | g_t | response_len | perturbation
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from PIL import Image
ROOT = Path("/mnt/local-fast/opd_zt")


def parse_args() -> argparse.Namespace:
    """Build the CLI for the grounding-score script and parse ``sys.argv``.

    Defaults point at ``ROOT/data`` for both the SFT input parquet and the
    per-token output parquet.
    """
    parser = argparse.ArgumentParser()
    data_dir = ROOT / "data"

    # Input / output locations.
    parser.add_argument("--input", default=str(data_dir / "sft_5k.parquet"))
    parser.add_argument("--output", default=str(data_dir / "grounding_g_t.parquet"))

    # Teacher model and sampling of the dataset.
    parser.add_argument("--teacher", default="Qwen/Qwen2.5-VL-72B-Instruct")
    parser.add_argument("--tp_size", type=int, default=4)
    parser.add_argument("--num_samples", type=int, default=500)
    parser.add_argument("--seed", type=int, default=0)

    # How the "perturbed" video condition is constructed.
    parser.add_argument(
        "--perturbation",
        choices=["black_frames", "shuffle", "mean_frame"],
        default="black_frames",
    )

    # vLLM engine sizing and checkpoint cadence.
    parser.add_argument("--max_model_len", type=int, default=9216)
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.80)
    parser.add_argument("--checkpoint_every", type=int, default=20)

    return parser.parse_args()
def perturb_video_frames(
    frames: list[Image.Image],
    perturbation: str,
    rng: np.random.Generator | None = None,
) -> list[Image.Image]:
    """Return a frame-count-preserving corrupted version of ``frames``.

    Perturbations:
        black_frames: every frame replaced by an all-zero image with the mode
            and size of ``frames[0]``.
        shuffle: the original frame objects in a random temporal order.
        mean_frame: every frame replaced by the pixel-wise mean over all frames
            (truncated to uint8).

    Args:
        frames: decoded video frames. ``mean_frame`` stacks them, so they are
            assumed to share size and mode (TODO confirm for all data sources).
        perturbation: one of the names above.
        rng: generator used by ``shuffle``. When None, NumPy's global RNG is
            used (the historical behavior); pass a seeded generator for
            reproducible permutations without touching global state.

    Returns:
        A new list with ``len(frames)`` frames; ``[]`` when ``frames`` is empty.

    Raises:
        ValueError: if ``perturbation`` is not a known name.
    """
    # Validate the name first so a typo is reported even for an empty clip.
    if perturbation not in ("black_frames", "shuffle", "mean_frame"):
        raise ValueError(perturbation)
    # Nothing to perturb; the branches below would index frames[0].
    if not frames:
        return []

    if perturbation == "black_frames":
        black = Image.new(frames[0].mode, frames[0].size, color=0)
        return [black.copy() for _ in frames]

    if perturbation == "shuffle":
        order = list(range(len(frames)))
        # Both np.random.shuffle and Generator.shuffle permute a list in place.
        (np.random if rng is None else rng).shuffle(order)
        return [frames[i] for i in order]

    # mean_frame
    arrs = np.stack([np.asarray(f, dtype=np.float32) for f in frames], axis=0)
    mean = arrs.mean(axis=0).astype(np.uint8)
    mean_img = Image.fromarray(mean, mode=frames[0].mode)
    return [mean_img.copy() for _ in frames]
def load_video_frames(video_field: dict) -> tuple[list[Image.Image], dict]:
    """Decode one video into RGB PIL frames plus the qwen_vl_utils video kwargs.

    Intended to mirror ``build_sft_dataset.make_video_dict``. NOTE(review): that
    helper is not in this file, so keep the defaults below in sync by hand.

    Args:
        video_field: dict with a ``"video"`` URI and optional ``max_frames``,
            ``min_frames``, ``nframes`` and ``max_pixels`` entries, e.g.
            ``{"video": "file:///.../X.mp4", "max_frames": 32, "min_frames": 32,
            "nframes": 32, "max_pixels": 151200}``.

    Returns:
        ``(frames, video_kwargs)``: ``frames`` is a list of RGB PIL images and
        ``video_kwargs`` is what ``process_vision_info`` returned. When no video
        is produced, returns ``([], {})``.
    """
    from qwen_vl_utils import process_vision_info

    msg = [{
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": video_field["video"],
                "max_frames": int(video_field.get("max_frames", 32)),
                "min_frames": int(video_field.get("min_frames", 32)),
                "nframes": int(video_field.get("nframes", 32)),
                "max_pixels": int(video_field.get("max_pixels", 360 * 420)),
            },
            {"type": "text", "text": ""},
        ],
    }]
    _, video_inputs, video_kwargs = process_vision_info(msg, return_video_kwargs=True)
    if video_inputs is None or len(video_inputs) == 0:
        return [], {}

    # One video per message, so take the first entry.
    vid = video_inputs[0]
    if hasattr(vid, "permute"):  # torch tensor
        # Assumed (T, C, H, W) with values in [0, 255] -- TODO confirm per qwen_vl_utils version.
        import torch  # local import: torch is only needed on this path

        if vid.dtype != torch.uint8:
            # Round before casting so float pixels are not biased downward by truncation.
            if vid.is_floating_point():
                vid = vid.round()
            vid = vid.clamp(0, 255).to(torch.uint8)
        # One host transfer for the whole clip, channels-last for PIL.
        arr = vid.permute(0, 2, 3, 1).cpu().numpy()
    else:  # numpy-like
        # Assumed (T, H, W, C) -- TODO confirm.
        arr = np.asarray(vid)
        if arr.dtype != np.uint8:
            # Same normalization as the torch path.
            if np.issubdtype(arr.dtype, np.floating):
                arr = np.rint(arr)
            arr = np.clip(arr, 0, 255).astype(np.uint8)

    arr = np.ascontiguousarray(arr)
    frames = [Image.fromarray(arr[i], mode="RGB") for i in range(arr.shape[0])]
    return frames, video_kwargs
def build_chat_text(processor, user_text: str, assistant_text: str) -> str:
    """Render a single user -> assistant exchange through the processor's chat template.

    The user turn holds a bare ``{"type": "video"}`` placeholder (no frames are
    attached here; the caller hands them to vLLM separately via
    ``multi_modal_data``) followed by ``user_text``. The assistant turn carries
    ``assistant_text``. The template is rendered to a string
    (``tokenize=False``) without a generation prompt.
    """
    video_slot = {"type": "video"}
    text_slot = {"type": "text", "text": user_text}
    user_turn = {"role": "user", "content": [video_slot, text_slot]}
    assistant_turn = {"role": "assistant", "content": assistant_text}
    conversation = [user_turn, assistant_turn]
    return processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
def extract_response_logprobs(
    prompt_logprobs: list[dict | None],
    prompt_ids: list[int],
    response_start: int,
) -> tuple[list[float], list[int]]:
    """Collect the log-prob of each realized token from ``response_start`` onward.

    ``prompt_logprobs[i]`` is read as a mapping from candidate token id to a
    log-prob entry for position ``i``, and ``prompt_ids[i]`` is the token
    actually present there. The entry may expose ``.logprob`` (vLLM ``Logprob``
    object) or be a plain dict with a ``"logprob"`` key.

    A position is silently skipped when it has no slot (``None`` or past the end
    of ``prompt_logprobs``), the slot is not a dict, the realized token is not in
    the slot, or no log-prob value can be read. The two returned lists stay
    aligned with each other, but not necessarily with absolute positions.

    Returns:
        ``(logprobs, token_ids)`` for the kept positions, in order.
    """

    def _read_logprob(entry):
        # Attribute form first (vLLM Logprob), then dict form.
        value = getattr(entry, "logprob", None)
        if value is None and isinstance(entry, dict):
            value = entry.get("logprob")
        return value

    n_slots = len(prompt_logprobs)
    logprobs: list[float] = []
    token_ids: list[int] = []

    for position in range(response_start, len(prompt_ids)):
        distribution = prompt_logprobs[position] if position < n_slots else None
        if distribution is None:
            continue
        token_id = prompt_ids[position]
        entry = distribution.get(token_id) if isinstance(distribution, dict) else None
        if entry is None:
            continue
        value = _read_logprob(entry)
        if value is None:
            continue
        logprobs.append(float(value))
        token_ids.append(int(token_id))

    return logprobs, token_ids
def _response_start(prompt_ids: list[int], suffix_ids: list[int]) -> int | None:
    """Return the index where ``suffix_ids`` begins as the tail of ``prompt_ids``.

    The token ids vLLM reports for a multimodal prompt need not equal a plain
    re-tokenization of the rendered text (the single video placeholder may be
    expanded into many tokens), so counting prefix tokens is unreliable.
    Everything after the assistant marker is plain text, so it is tokenized on
    its own and anchored at the END of ``prompt_ids`` instead.

    Returns None when ``prompt_ids`` does not end with ``suffix_ids``.
    """
    start = len(prompt_ids) - len(suffix_ids)
    if start < 0 or list(prompt_ids[start:]) != list(suffix_ids):
        return None
    return start


def _score_sample(llm, processor, sp, row, idx: int, perturbation: str) -> list[dict] | None:
    """Compute the per-token g_t rows for one dataset row.

    Returns None (after printing a ``[grounding] skip idx=...`` reason) when the
    row cannot be scored.
    """
    # prompt is ndarray of 1 dict
    prompt_msgs = list(row["prompt"])
    # The user prompt content already contains "<video>\n" prefix in our pipeline;
    # strip it because we render the video placeholder via chat template.
    user_text = prompt_msgs[0]["content"].replace("<video>\n", "").replace("<video>", "").strip()
    response_text = str(row["response"])
    video_field = dict(row["videos"][0])
    data_source = str(row["data_source"])
    tokenizer = processor.tokenizer

    # --- decode + perturb video ---
    try:
        frames, video_kwargs = load_video_frames(video_field)
    except Exception as e:  # best-effort: one bad video must not abort the whole run
        print(f"[grounding] skip idx={idx} (video decode failed: {e})", flush=True)
        return None
    if not frames:
        print(f"[grounding] skip idx={idx} (no frames)", flush=True)
        return None
    perturbed_frames = perturb_video_frames(frames, perturbation)

    # --- render chat to text ---
    text = build_chat_text(processor, user_text, response_text)
    # Qwen2.5-VL chat template puts assistant content after "<|im_start|>assistant\n".
    marker = "<|im_start|>assistant\n"
    if marker not in text:
        print(f"[grounding] skip idx={idx} (no assistant marker)", flush=True)
        return None
    # Everything the template renders after the marker is scored (as before),
    # including any closing template tokens it may append after the response.
    suffix_text = text[text.rindex(marker) + len(marker):]
    suffix_ids = tokenizer(suffix_text, add_special_tokens=False)["input_ids"]

    # --- run vLLM on both conditions in one batch (outputs keep input order) ---
    requests = [
        {
            "prompt": text,
            "multi_modal_data": {"video": video},
            "mm_processor_kwargs": video_kwargs,
        }
        for video in (frames, perturbed_frames)
    ]
    try:
        out_v, out_vp = llm.generate(requests, sampling_params=sp, use_tqdm=False)
    except Exception as e:
        print(f"[grounding] skip idx={idx} (vllm failed: {e})", flush=True)
        return None

    # Locate the response independently in each run's own token ids, so the
    # token-id comparison below is a real alignment check.
    prompt_ids_v = list(out_v.prompt_token_ids or [])
    prompt_ids_vp = list(out_vp.prompt_token_ids or [])
    start_v = _response_start(prompt_ids_v, suffix_ids)
    start_vp = _response_start(prompt_ids_vp, suffix_ids)
    if start_v is None or start_vp is None:
        print(f"[grounding] skip idx={idx} (response boundary not found)", flush=True)
        return None

    logp_v, tok_ids_v = extract_response_logprobs(out_v.prompt_logprobs or [], prompt_ids_v, start_v)
    logp_vp, tok_ids_vp = extract_response_logprobs(out_vp.prompt_logprobs or [], prompt_ids_vp, start_vp)
    if len(logp_v) != len(logp_vp) or tok_ids_v != tok_ids_vp:
        print(
            f"[grounding] skip idx={idx} (logprob length mismatch v={len(logp_v)} vp={len(logp_vp)})",
            flush=True,
        )
        return None
    if not logp_v:
        # Guard against prompt logprobs silently not being returned.
        print(f"[grounding] skip idx={idx} (no response logprobs returned)", flush=True)
        return None

    response_len = len(tok_ids_v)
    return [
        {
            "sample_idx": int(idx),
            "data_source": data_source,
            "token_pos": pos,
            "token_id": tid,
            "token_text": tokenizer.decode([tid]),
            "logp_v": lv,
            "logp_v_perturbed": lvp,
            "g_t": lv - lvp,
            "response_len": response_len,
            "perturbation": perturbation,
        }
        for pos, (tid, lv, lvp) in enumerate(zip(tok_ids_v, logp_v, logp_vp, strict=True))
    ]


def main() -> int:
    """Score a random subset of SFT rows with the teacher and write per-token g_t.

    Returns the process exit code: 0 on success, 2 when no rows were produced.
    Partial results are checkpointed to ``<output>.part`` every
    ``--checkpoint_every`` samples (0 disables checkpointing).
    """
    args = parse_args()
    rng = np.random.default_rng(args.seed)
    # perturb_video_frames is not handed ``rng``; its "shuffle" permutation comes
    # from NumPy's global RNG, so seed that too to make --seed cover it.
    np.random.seed(args.seed)

    # --- load data ---
    df = pd.read_parquet(args.input)
    n_total = len(df)
    if args.num_samples >= n_total:
        sample_idx = np.arange(n_total)
    else:
        sample_idx = np.sort(rng.choice(n_total, size=args.num_samples, replace=False))
    print(f"[grounding] sampling {len(sample_idx)} / {n_total} rows", flush=True)

    # --- load model & processor ---
    print(f"[grounding] loading vLLM teacher: {args.teacher}", flush=True)
    from vllm import LLM, SamplingParams
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(args.teacher, trust_remote_code=True)
    llm = LLM(
        model=args.teacher,
        tensor_parallel_size=args.tp_size,
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 0, "video": 1},
        enforce_eager=False,
        dtype="bfloat16",
    )
    # One generated token is required by generate(); only the prompt logprobs matter.
    sp = SamplingParams(
        max_tokens=1,
        temperature=1.0,
        prompt_logprobs=0,  # return just realized-token logprobs
    )

    # --- iterate ---
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_suffix(out_path.suffix + ".part")
    rows: list[dict] = []
    t0 = time.time()
    n_samples = len(sample_idx)
    for i, idx in enumerate(sample_idx, start=1):
        sample_rows = _score_sample(
            llm, processor, sp, df.iloc[int(idx)], int(idx), args.perturbation
        )
        if sample_rows is not None:
            rows.extend(sample_rows)

        # Progress and checkpointing run for every sample, including skipped ones,
        # so a skip never suppresses a scheduled checkpoint.
        if i % 10 == 0:
            dt = time.time() - t0
            eta = dt / i * (n_samples - i)
            print(
                f"[grounding] {i}/{n_samples} samples rows={len(rows)} "
                f"elapsed={dt/60:.1f}min ETA={eta/60:.1f}min",
                flush=True,
            )
        if args.checkpoint_every > 0 and i % args.checkpoint_every == 0 and rows:
            pd.DataFrame(rows).to_parquet(tmp_path, index=False)

    if not rows:
        print("[grounding] no rows produced; exiting non-zero", flush=True)
        return 2
    pd.DataFrame(rows).to_parquet(out_path, index=False)
    if tmp_path.exists():
        tmp_path.unlink()
    print(f"[grounding] wrote {len(rows)} rows -> {out_path}", flush=True)
    return 0


if __name__ == "__main__":
    sys.exit(main())