forensics-grpo / code /src /open_r1 /binary_prober.py
sdzt's picture
Add source code
33569f9 verified
Raw
History Blame Contribute Delete
8.43 kB
"""Frozen-reference binary probing infrastructure (R1 / R3 reward).
Loads a frozen Qwen2.5-VL once per process and exposes a batched API:
prober = get_prober()
results = prober.probe_batch(video_clips, fps_list, questions)
# results[i] = (P(yes | clip_i, question_i), P(no | clip_i, question_i)),
# renormalised over the yes/no answer tokens so each pair sums to 1
Plus a `slice_video_by_time` helper for cutting Qwen-cached video tensors by
time range. The frozen reference avoids reward hacking: the policy cannot
shift the prober's answers during training.
Environment:
FORENSICS_PROBE_MODEL path to frozen Qwen2.5-VL checkpoint
(required: get_prober() raises RuntimeError if it is unset;
typically the same Qwen2.5-VL-7B-Instruct path as the policy)
FORENSICS_PROBE_DEVICE override device (default: cuda:{LOCAL_RANK})
FORENSICS_PROBE_MAX_PIXELS processor.max_pixels (default: 3584 * 28*28)
FORENSICS_PROBE_MIN_PIXELS processor.min_pixels (default: 16 * 28*28)
"""
from __future__ import annotations
import os
import threading
from typing import List, Optional, Tuple
import contextlib
import traceback
import numpy as np
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
@contextlib.contextmanager
def _no_deepspeed_zero3():
"""Temporarily disable HF transformers' DeepSpeed ZeRO-3 integration so a
standalone frozen model can be loaded WITHOUT having its weights partitioned
by the policy's deepspeed engine. Needed when the prober is instantiated
inside a training process where the policy is under ZeRO-3."""
import importlib
try:
ds_mod = importlib.import_module("transformers.integrations.deepspeed")
except Exception:
yield
return
saved_ref = getattr(ds_mod, "_hf_deepspeed_config_weak_ref", None)
try:
ds_mod._hf_deepspeed_config_weak_ref = None
yield
finally:
ds_mod._hf_deepspeed_config_weak_ref = saved_ref
_INSTANCE: Optional["BinaryProber"] = None
_INSTANCE_LOCK = threading.Lock()
def _local_rank_device() -> str:
rank = int(os.environ.get("LOCAL_RANK", "0"))
return f"cuda:{rank}"
class BinaryProber:
    """Frozen Qwen2.5-VL prober for binary yes/no questions over short clips.

    The model is loaded once, moved to ``self.device``, switched to eval mode
    and has ``requires_grad`` disabled on every parameter. Nothing in the
    training loop can therefore update it.

    Args:
        model_path: Checkpoint passed to ``from_pretrained`` for both the
            processor and the model.
        device: Target device. Falls back to ``$FORENSICS_PROBE_DEVICE``,
            then to ``cuda:{LOCAL_RANK}``.
        dtype: Weight dtype passed to ``from_pretrained``.
        max_pixels: Processor ``max_pixels``. Falls back to
            ``$FORENSICS_PROBE_MAX_PIXELS``, default ``3584 * 28 * 28``.
        min_pixels: Processor ``min_pixels``. Falls back to
            ``$FORENSICS_PROBE_MIN_PIXELS``, default ``16 * 28 * 28``.
        attn_implementation: Attention backend for ``from_pretrained``.
            Defaults to ``"flash_attention_2"`` (the former hard-coded value);
            pass e.g. ``"sdpa"`` where flash-attn is unavailable.
    """

    # Surface forms of each answer. Every variant that encodes to a single
    # token contributes its probability mass to that answer in probe_batch.
    _YES_VARIANTS: Tuple[str, ...] = ("yes", " yes", "Yes", " Yes")
    _NO_VARIANTS: Tuple[str, ...] = ("no", " no", "No", " No")

    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        dtype: torch.dtype = torch.bfloat16,
        max_pixels: Optional[int] = None,
        min_pixels: Optional[int] = None,
        attn_implementation: str = "flash_attention_2",
    ):
        self.device = device or os.environ.get(
            "FORENSICS_PROBE_DEVICE", _local_rank_device()
        )
        self.dtype = dtype
        if max_pixels is None:
            max_pixels = int(os.environ.get(
                "FORENSICS_PROBE_MAX_PIXELS", str(3584 * 28 * 28)
            ))
        if min_pixels is None:
            min_pixels = int(os.environ.get(
                "FORENSICS_PROBE_MIN_PIXELS", str(16 * 28 * 28)
            ))
        with _no_deepspeed_zero3():
            self.processor = AutoProcessor.from_pretrained(
                model_path, max_pixels=max_pixels, min_pixels=min_pixels
            )
            self.tokenizer = self.processor.tokenizer
            # probe_batch reads logits[:, -1, :], which is the last REAL token
            # only under left padding. Set it on the tokenizer too, in case the
            # installed processor version does not honour the call-time
            # padding_side kwarg.
            self.tokenizer.padding_side = "left"
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype=dtype,
                attn_implementation=attn_implementation,
            ).to(self.device).eval()
        self.model.requires_grad_(False)

        # With add_generation_prompt=True the prompt stops where the assistant
        # answer begins, so the next token is the first word of the answer.
        # Its casing and leading space depend on the model. All single-token
        # variants are therefore pooled (yes_token_ids / no_token_ids). The
        # single ids are kept for callers that still read them.
        self.yes_token_id = self._pick_token_id(self._YES_VARIANTS)
        self.no_token_id = self._pick_token_id(self._NO_VARIANTS)
        self.yes_token_ids = self._collect_token_ids(self._YES_VARIANTS)
        self.no_token_ids = self._collect_token_ids(self._NO_VARIANTS)

        # Full (batch, seq_len, vocab) logits for long video prompts are very
        # large. Ask for the last position only when the forward supports it.
        self._supports_logits_to_keep = self._forward_accepts("logits_to_keep")

    def _pick_token_id(self, variants: Tuple[str, ...]) -> int:
        """Pick the first variant that tokenises to exactly one token.

        Falls back to the first token of ``variants[0]`` when none does.
        """
        for v in variants:
            ids = self.tokenizer.encode(v, add_special_tokens=False)
            if len(ids) == 1:
                return ids[0]
        return self.tokenizer.encode(variants[0], add_special_tokens=False)[0]

    def _collect_token_ids(self, variants: Tuple[str, ...]) -> List[int]:
        """Return ids of every variant that tokenises to exactly one token.

        Order follows ``variants`` and duplicates are dropped. If no variant is
        a single token, this returns ``[self._pick_token_id(variants)]``, so
        the result is never empty.
        """
        ids: List[int] = []
        for v in variants:
            enc = self.tokenizer.encode(v, add_special_tokens=False)
            if len(enc) == 1 and enc[0] not in ids:
                ids.append(enc[0])
        return ids or [self._pick_token_id(variants)]

    def _forward_accepts(self, name: str) -> bool:
        """True if ``self.model.forward`` declares a parameter called ``name``."""
        import inspect

        try:
            return name in inspect.signature(self.model.forward).parameters
        except (TypeError, ValueError):
            return False

    @staticmethod
    def _yes_no_probs(
        last_logits: torch.Tensor,
        yes_ids: List[int],
        no_ids: List[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Renormalise next-token logits over the {yes, no} answer pair.

        Args:
            last_logits: (batch, vocab) logits for the first answer position.
            yes_ids / no_ids: Token ids whose mass is pooled (logsumexp) into
                the yes / no answer.

        Returns:
            ``(p_yes, p_no)``, each of shape (batch,), with ``p_yes + p_no == 1``.
        """
        yes_mass = torch.logsumexp(last_logits[:, yes_ids], dim=-1)
        no_mass = torch.logsumexp(last_logits[:, no_ids], dim=-1)
        probs = torch.softmax(torch.stack([yes_mass, no_mass], dim=-1), dim=-1)
        return probs[:, 0], probs[:, 1]

    def _build_chat_text(self, question: str) -> str:
        """Render one user turn (video + question) through the chat template."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video"},
                    {
                        "type": "text",
                        "text": question + "\n\nAnswer with a single word: yes or no.",
                    },
                ],
            },
        ]
        return self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    @torch.no_grad()
    def probe_batch(
        self,
        video_clips: List[torch.Tensor],
        fps_list: List[float],
        questions: List[str],
    ) -> List[Tuple[float, float]]:
        """Run probes; return ``[(P(yes), P(no)), ...]``, one pair per probe.

        Args:
            video_clips: One clip per probe. Each must be a (T, C, H, W) tensor
                with T >= 2 (Qwen2.5-VL temporal_patch_size=2 requires an even
                frame count).
            fps_list: Frame rate per clip, forwarded to the processor as ``fps``.
            questions: Question text per probe. The yes/no instruction is
                appended by ``_build_chat_text``.

        Returns:
            One ``(p_yes, p_no)`` float pair per probe, renormalised over the
            yes/no answer tokens so each pair sums to 1.

        Raises:
            ValueError: if the three input lists differ in length.

        All probes are forwarded in a single batch; the caller should chunk
        by GPU memory.
        """
        num = len(video_clips)
        if len(fps_list) != num or len(questions) != num:
            raise ValueError(
                "probe_batch: video_clips, fps_list and questions must have "
                f"the same length (got {num}, {len(fps_list)}, {len(questions)})"
            )
        if num == 0:
            return []

        prompts_text = [self._build_chat_text(q) for q in questions]
        with _no_deepspeed_zero3():
            inputs = self.processor(
                text=prompts_text,
                videos=video_clips,
                fps=fps_list,
                padding=True,
                return_tensors="pt",
                padding_side="left",
                add_special_tokens=False,
            )
        inputs = {
            k: (v.to(self.device) if hasattr(v, "to") else v)
            for k, v in inputs.items()
        }

        forward_kwargs = {"use_cache": False}
        if self._supports_logits_to_keep:
            forward_kwargs["logits_to_keep"] = 1
        with _no_deepspeed_zero3():
            outputs = self.model(**inputs, **forward_kwargs)

        # With left padding the final position is the last prompt token, so
        # its logits predict the first word of the assistant answer.
        last_logits = outputs.logits[:, -1, :].float()
        p_yes, p_no = self._yes_no_probs(
            last_logits, self.yes_token_ids, self.no_token_ids
        )
        return [
            (float(y), float(n))
            for y, n in zip(p_yes.cpu().tolist(), p_no.cpu().tolist())
        ]
def get_prober() -> BinaryProber:
    """Return the process-wide BinaryProber, building it on first use.

    The fast path returns the cached instance without locking. Otherwise the
    lock is taken and ``_INSTANCE`` is re-checked before constructing, so
    concurrent first callers share a single load. If construction raises,
    ``_INSTANCE`` stays ``None`` and a later call retries.

    Raises:
        RuntimeError: if ``FORENSICS_PROBE_MODEL`` is unset or empty when
            the instance has to be created.
    """
    global _INSTANCE
    if _INSTANCE is not None:
        return _INSTANCE
    with _INSTANCE_LOCK:
        if _INSTANCE is None:
            checkpoint = os.environ.get("FORENSICS_PROBE_MODEL")
            if not checkpoint:
                raise RuntimeError(
                    "FORENSICS_PROBE_MODEL is not set. Point it at a "
                    "frozen Qwen2.5-VL checkpoint."
                )
            _INSTANCE = BinaryProber(model_path=checkpoint)
    return _INSTANCE
def slice_video_by_time(
    video_input,
    fps: float,
    start_s: float,
    end_s: float,
    min_frames: int = 4,
) -> Optional[torch.Tensor]:
    """Return frames in [start_s, end_s] as (T, C, H, W). None if too short.

    Args:
        video_input: A 4-D (T, C, H, W) tensor or array-like, or a list/tuple
            of per-frame (C, H, W) tensors/arrays. A list/tuple is stacked
            along a new leading time axis.
        fps: Frame rate used to convert seconds to frame indices.
        start_s: Window start in seconds; frame ``round(start_s * fps)``
            (inclusive), clamped to ``>= 0``.
        end_s: Window end in seconds; frame ``round(end_s * fps)``
            (exclusive), clamped to ``<= T``.
        min_frames: Minimum clip length. A shorter window is extended forward;
            if that reaches the end of the video, its start moves earlier.

    Returns:
        A contiguous slice with at least ``min_frames`` frames and an even
        frame count (Qwen2.5-VL temporal_patch_size=2; the boundary is snapped
        outward when needed). Returns None when the window is empty, the video
        is too short, or the input is not 4-D / its frames cannot be stacked.
    """
    if isinstance(video_input, (list, tuple)):
        # Some pipelines return a list of per-frame tensors: stack them.
        if not video_input:
            return None
        try:
            video_input = torch.stack([torch.as_tensor(f) for f in video_input])
        except (RuntimeError, TypeError, ValueError):
            # Ragged or non-numeric frames: treat like any other unusable input.
            return None
    elif not torch.is_tensor(video_input):
        video_input = torch.as_tensor(video_input)
    if video_input.ndim != 4:
        return None

    total = video_input.shape[0]
    start_f = max(0, int(round(start_s * fps)))
    end_f = min(total, int(round(end_s * fps)))
    if end_f <= start_f:
        return None

    if end_f - start_f < min_frames:
        # Grow forward. When capped at the video end, start moves back so the
        # window keeps exactly min_frames frames.
        deficit = min_frames - (end_f - start_f)
        end_f = min(total, end_f + deficit)
        start_f = max(0, end_f - min_frames)
        if end_f - start_f < min_frames:
            return None

    # Even-frame constraint for temporal patchification.
    if (end_f - start_f) % 2 != 0:
        if end_f < total:
            end_f += 1
        elif start_f > 0:
            start_f -= 1
        else:
            return None
    return video_input[start_f:end_f].contiguous()