"""Frozen-reference binary probing infrastructure (R1 / R3 reward).

Loads a frozen Qwen2.5-VL once per process and exposes a batched API:
prober = get_prober()
results = prober.probe_batch(video_clips, fps_list, questions)
# results[i] = (P(yes | clip_i, question_i), P(no | clip_i, question_i))
Plus a `slice_video_by_time` helper for cutting Qwen-cached video tensors by
time range. The frozen reference avoids reward hacking: the policy cannot
shift the prober's answers during training.
Environment:
FORENSICS_PROBE_MODEL path to frozen Qwen2.5-VL checkpoint
(default: same Qwen2.5-VL-7B-Instruct path
as policy)
FORENSICS_PROBE_DEVICE override device (default: cuda:{LOCAL_RANK})
FORENSICS_PROBE_MAX_PIXELS processor.max_pixels (default: 3584 * 28*28)
FORENSICS_PROBE_MIN_PIXELS processor.min_pixels (default: 16 * 28*28)
"""
from __future__ import annotations
import os
import threading
from typing import List, Optional, Tuple
import contextlib
import traceback
import numpy as np
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
@contextlib.contextmanager
def _no_deepspeed_zero3():
    """Temporarily hide HF transformers' DeepSpeed ZeRO-3 integration.

    While the context is active, ``_hf_deepspeed_config_weak_ref`` in
    ``transformers.integrations.deepspeed`` is set to ``None`` so a standalone
    frozen model can be loaded WITHOUT having its weights partitioned by the
    policy's deepspeed engine. Needed when the prober is instantiated inside a
    training process where the policy is under ZeRO-3. The previous value is
    restored on exit, including when the ``with`` body raises.

    If the integration module cannot be imported, this is a no-op.

    NOTE: this mutates module-level state of the transformers package.
    """
    import importlib

    try:
        ds_mod = importlib.import_module("transformers.integrations.deepspeed")
    except Exception:
        # Best-effort: without the integration module there is nothing to hide.
        ds_mod = None

    # Yield OUTSIDE the ``except`` clause. Otherwise an exception raised by the
    # caller's ``with`` body would be implicitly chained to the import failure.
    if ds_mod is None:
        yield
        return

    saved_ref = getattr(ds_mod, "_hf_deepspeed_config_weak_ref", None)
    try:
        ds_mod._hf_deepspeed_config_weak_ref = None
        yield
    finally:
        ds_mod._hf_deepspeed_config_weak_ref = saved_ref
# Serialises first-time construction of the shared prober in ``get_prober``.
_INSTANCE_LOCK = threading.Lock()
# Process-wide prober; stays ``None`` until ``get_prober`` first builds it.
_INSTANCE: Optional["BinaryProber"] = None
def _local_rank_device() -> str:
    """Build the CUDA device string for this process's ``LOCAL_RANK``.

    ``LOCAL_RANK`` is parsed as an int and defaults to ``"0"`` when unset.
    """
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    return "cuda:%d" % local_rank
class BinaryProber:
    """Frozen Qwen2.5-VL prober for binary yes/no questions over short clips.

    The model is loaded once, moved to ``self.device`` and put in eval mode.
    Every parameter has ``requires_grad`` disabled. Each probe reports
    P(yes) and P(no) renormalised over just the two answer tokens.
    """

    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        dtype: torch.dtype = torch.bfloat16,
        max_pixels: Optional[int] = None,
        min_pixels: Optional[int] = None,
    ):
        """Load the frozen processor and model.

        Args:
            model_path: Checkpoint passed to ``from_pretrained``.
            device: Target device. Falls back to ``FORENSICS_PROBE_DEVICE``,
                then to ``cuda:{LOCAL_RANK}``.
            dtype: Weight dtype for the model (default ``torch.bfloat16``).
            max_pixels: Processor ``max_pixels``. Falls back to
                ``FORENSICS_PROBE_MAX_PIXELS``, then ``3584 * 28 * 28``.
            min_pixels: Processor ``min_pixels``. Falls back to
                ``FORENSICS_PROBE_MIN_PIXELS``, then ``16 * 28 * 28``.
        """
        self.device = device or os.environ.get(
            "FORENSICS_PROBE_DEVICE", _local_rank_device()
        )
        self.dtype = dtype
        if max_pixels is None:
            max_pixels = int(os.environ.get(
                "FORENSICS_PROBE_MAX_PIXELS", str(3584 * 28 * 28)
            ))
        if min_pixels is None:
            min_pixels = int(os.environ.get(
                "FORENSICS_PROBE_MIN_PIXELS", str(16 * 28 * 28)
            ))

        # Load with the policy's ZeRO-3 integration hidden so the frozen weights
        # are not partitioned by the policy's deepspeed engine.
        with _no_deepspeed_zero3():
            self.processor = AutoProcessor.from_pretrained(
                model_path, max_pixels=max_pixels, min_pixels=min_pixels
            )
            self.tokenizer = self.processor.tokenizer
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype=dtype,
                attn_implementation="flash_attention_2",
            ).to(self.device).eval()
        # Frozen reference: no parameter should ever receive gradients.
        self.model.requires_grad_(False)

        # Candidate spellings are tried in order, and the first one that is a
        # single token wins. Lowercase "yes"/"no" without a leading space is
        # therefore used whenever the tokenizer has it as one token.
        # NOTE(review): only that one token's logit is scored. Probability mass
        # the model puts on other casings (e.g. "Yes"/"No") is ignored. Confirm
        # this is intended for the prompt built in ``_build_chat_text``.
        self.yes_token_id = self._pick_token_id(("yes", " yes", "Yes", " Yes"))
        self.no_token_id = self._pick_token_id(("no", " no", "No", " No"))

    def _pick_token_id(self, variants: Tuple[str, ...]) -> int:
        """Return the id of the first variant that encodes to exactly one token.

        If no variant is a single token, fall back to the first token of
        ``variants[0]``.
        """
        for v in variants:
            ids = self.tokenizer.encode(v, add_special_tokens=False)
            if len(ids) == 1:
                return ids[0]
        return self.tokenizer.encode(variants[0], add_special_tokens=False)[0]

    def _build_chat_text(self, question: str) -> str:
        """Render a one-video user turn asking ``question`` with a yes/no suffix.

        The generation prompt is appended, so the model's next token is the
        start of the assistant's answer.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video"},
                    {
                        "type": "text",
                        "text": question + "\n\nAnswer with a single word: yes or no.",
                    },
                ],
            },
        ]
        return self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    @torch.no_grad()
    def probe_batch(
        self,
        video_clips: List[torch.Tensor],
        fps_list: List[float],
        questions: List[str],
    ) -> List[Tuple[float, float]]:
        """Run probes and return ``[(P(yes), P(no)), ...]`` in input order.

        ``video_clips[i]`` must be a (T, C, H, W) tensor with T >= 2 and even
        (Qwen2.5-VL temporal_patch_size=2). ``fps_list[i]`` and ``questions[i]``
        belong to ``video_clips[i]``. All probes are forwarded in a single
        batch, so the caller should chunk by GPU memory.

        Raises:
            ValueError: if ``questions`` (or ``fps_list``, when it is a list or
                tuple) does not have exactly one entry per clip.
        """
        n_clips = len(video_clips)
        if len(questions) != n_clips:
            raise ValueError(
                f"probe_batch: got {n_clips} video clips but "
                f"{len(questions)} questions"
            )
        # Only a list/tuple is length-checked; anything else is passed through
        # to the processor unchanged.
        if isinstance(fps_list, (list, tuple)) and len(fps_list) != n_clips:
            raise ValueError(
                f"probe_batch: got {n_clips} video clips but "
                f"{len(fps_list)} fps values"
            )
        if n_clips == 0:
            return []

        prompts_text = [self._build_chat_text(q) for q in questions]
        with _no_deepspeed_zero3():
            inputs = self.processor(
                text=prompts_text,
                videos=video_clips,
                fps=fps_list,
                padding=True,
                return_tensors="pt",
                padding_side="left",
                add_special_tokens=False,
            )
        inputs = {
            k: (v.to(self.device) if hasattr(v, "to") else v)
            for k, v in inputs.items()
        }
        with _no_deepspeed_zero3():
            outputs = self.model(**inputs, use_cache=False)

        # With left padding, position -1 is the final prompt token in every row.
        # Its logits predict the first token of the assistant's answer.
        last_logits = outputs.logits[:, -1, :].float()
        pair_logits = torch.stack(
            (last_logits[:, self.yes_token_id], last_logits[:, self.no_token_id]),
            dim=-1,
        )
        # A softmax over just these two logits renormalises P(yes) and P(no) so
        # that they sum to 1.
        probs = torch.softmax(pair_logits, dim=-1).cpu().tolist()
        return [(float(p_yes), float(p_no)) for p_yes, p_no in probs]
def get_prober() -> BinaryProber:
    """Return the process-wide ``BinaryProber``, building it on first call.

    The checkpoint path comes from ``FORENSICS_PROBE_MODEL``. Construction runs
    under ``_INSTANCE_LOCK``, and the instance is re-checked once the lock is
    held. Threads racing on the first call therefore build at most one prober.

    Raises:
        RuntimeError: if the prober must be built and ``FORENSICS_PROBE_MODEL``
            is unset or empty.
    """
    global _INSTANCE
    if _INSTANCE is not None:
        return _INSTANCE
    with _INSTANCE_LOCK:
        if _INSTANCE is None:
            checkpoint = os.environ.get("FORENSICS_PROBE_MODEL")
            if not checkpoint:
                raise RuntimeError(
                    "FORENSICS_PROBE_MODEL is not set. Point it at a "
                    "frozen Qwen2.5-VL checkpoint."
                )
            _INSTANCE = BinaryProber(model_path=checkpoint)
    return _INSTANCE
def slice_video_by_time(
    video_input,
    fps: float,
    start_s: float,
    end_s: float,
    min_frames: int = 4,
) -> Optional[torch.Tensor]:
    """Return the frames covering ``[start_s, end_s]`` as a (T, C, H, W) tensor.

    Args:
        video_input: A 4-D frames-first tensor, anything ``torch.as_tensor``
            accepts, or a list/tuple of per-frame tensors. A per-frame list is
            stacked along a new leading dim.
        fps: Frame rate used to convert seconds to frame indices. Indices are
            rounded with ``round``.
        start_s: Range start in seconds, clamped to frame 0.
        end_s: Range end in seconds, clamped to the last frame.
        min_frames: Ranges shorter than this are widened to ``min_frames``
            frames. The end moves out first; the start is pulled back only if
            the end hits the video length.

    Returns:
        A contiguous slice with an even frame count, as required by Qwen2.5-VL
        temporal_patch_size=2. An odd count is fixed by snapping a boundary
        outward. If the range already spans the whole video, the last frame is
        dropped instead, provided at least ``max(min_frames, 2)`` frames remain.

        ``None`` is returned when:
          * the input is not 4-D, or is a ragged/empty frame list;
          * the time range is empty;
          * the video is too short for these constraints.
    """
    if not torch.is_tensor(video_input):
        if (
            isinstance(video_input, (list, tuple))
            and len(video_input) > 0
            and all(torch.is_tensor(frame) for frame in video_input)
        ):
            # ``torch.as_tensor`` cannot convert a list of multi-element
            # tensors, so per-frame tensors are stacked explicitly.
            try:
                video_input = torch.stack(tuple(video_input))
            except RuntimeError:
                # Frames with different shapes cannot form one video tensor.
                return None
        else:
            video_input = torch.as_tensor(video_input)
    if video_input.ndim != 4:
        return None

    num_frames = video_input.shape[0]
    start_f = max(0, int(round(start_s * fps)))
    end_f = min(num_frames, int(round(end_s * fps)))
    if end_f <= start_f:
        return None

    if end_f - start_f < min_frames:
        # Widen to min_frames: push the end out, then pull the start back if
        # the end was clamped at the video length.
        end_f = min(num_frames, start_f + min_frames)
        start_f = max(0, end_f - min_frames)
        if end_f - start_f < min_frames:
            return None

    # Even-frame constraint for temporal patchification.
    if (end_f - start_f) % 2:
        if end_f < num_frames:
            end_f += 1
        elif start_f > 0:
            start_f -= 1
        elif end_f - 1 - start_f >= max(min_frames, 2):
            # The range already spans the whole odd-length video. Drop the last
            # frame rather than discarding the clip.
            end_f -= 1
        else:
            return None
    return video_input[start_f:end_f].contiguous()