"""Frozen-reference binary probing infrastructure (R1 / R3 reward).

Loads a frozen Qwen2.5-VL once per process and exposes a batched API:

    prober = get_prober()
    results = prober.probe_batch(video_clips, fps_list, questions)
    # results[i] = (P(yes | clip_i, question_i), P(no | clip_i, question_i)),
    # renormalised over the two answer tokens so each pair sums to 1.

Also provides a `slice_video_by_time` helper for cutting Qwen-cached
(T, C, H, W) video tensors by time range. Using a frozen reference guards
against reward hacking: the policy cannot shift the prober's answers during
training.

Environment:
    FORENSICS_PROBE_MODEL       path to the frozen Qwen2.5-VL checkpoint
                                (required: get_prober() raises RuntimeError
                                when it is unset)
    FORENSICS_PROBE_DEVICE      override device (default: cuda:{LOCAL_RANK})
    FORENSICS_PROBE_MAX_PIXELS  processor.max_pixels (default: 3584 * 28*28)
    FORENSICS_PROBE_MIN_PIXELS  processor.min_pixels (default: 16 * 28*28)
"""
| from __future__ import annotations |
|
|
| import os |
| import threading |
| from typing import List, Optional, Tuple |
|
|
| import contextlib |
| import traceback |
|
|
| import numpy as np |
| import torch |
| from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration |
|
|
|
|
@contextlib.contextmanager
def _no_deepspeed_zero3():
    """Temporarily hide HF transformers' DeepSpeed ZeRO-3 config.

    Clears ``transformers.integrations.deepspeed._hf_deepspeed_config_weak_ref``
    for the duration of the ``with`` block and restores the previous value on
    exit, including when the block raises. This lets a standalone frozen model
    be loaded WITHOUT having its weights partitioned by the policy's DeepSpeed
    engine. That is needed when the prober is instantiated inside a training
    process where the policy is under ZeRO-3.

    Best-effort: if the integration module cannot be imported, or it no longer
    exposes the weak-ref attribute, the context manager is a no-op.
    """
    import importlib

    try:
        ds_mod = importlib.import_module("transformers.integrations.deepspeed")
    except Exception:
        # transformers (or its DeepSpeed integration) unavailable: nothing to hide.
        ds_mod = None

    # Yield OUTSIDE the except handler. Yielding inside it would make any
    # exception raised in the caller's with-block carry the import failure as
    # its implicit __context__, producing misleading tracebacks.
    if ds_mod is None or not hasattr(ds_mod, "_hf_deepspeed_config_weak_ref"):
        yield
        return

    saved_ref = ds_mod._hf_deepspeed_config_weak_ref
    ds_mod._hf_deepspeed_config_weak_ref = None
    try:
        yield
    finally:
        ds_mod._hf_deepspeed_config_weak_ref = saved_ref
|
|
|
|
# Lock taken by get_prober() while it lazily constructs the shared prober.
_INSTANCE_LOCK = threading.Lock()
# Process-wide BinaryProber; stays None until get_prober() first builds it.
_INSTANCE: Optional["BinaryProber"] = None
|
|
|
|
def _local_rank_device() -> str:
    """Return ``"cuda:<LOCAL_RANK>"``, treating a missing LOCAL_RANK as 0."""
    local_rank = os.environ.get("LOCAL_RANK", "0")
    return "cuda:{}".format(int(local_rank))
|
|
|
|
class BinaryProber:
    """Frozen Qwen2.5-VL prober for binary yes/no questions over short clips.

    The model is put in eval mode and every parameter has ``requires_grad``
    disabled, so probing never contributes gradients.
    """

    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        dtype: torch.dtype = torch.bfloat16,
        max_pixels: Optional[int] = None,
        min_pixels: Optional[int] = None,
        attn_implementation: str = "flash_attention_2",
    ):
        """Load the processor and the frozen model from ``model_path``.

        Args:
            model_path: checkpoint path / hub id used for both the processor
                and the Qwen2.5-VL model.
            device: target device. Falls back to ``FORENSICS_PROBE_DEVICE``,
                then to ``cuda:{LOCAL_RANK}``.
            dtype: weight dtype passed to ``from_pretrained``.
            max_pixels: processor pixel budget. Falls back to
                ``FORENSICS_PROBE_MAX_PIXELS``, then to 3584 * 28 * 28.
            min_pixels: processor pixel floor. Falls back to
                ``FORENSICS_PROBE_MIN_PIXELS``, then to 16 * 28 * 28.
            attn_implementation: attention backend handed to
                ``from_pretrained``. Defaults to the previously hard-coded
                ``"flash_attention_2"``; pass e.g. ``"sdpa"`` where flash-attn
                is not installed.
        """
        self.device = device or os.environ.get(
            "FORENSICS_PROBE_DEVICE", _local_rank_device()
        )
        self.dtype = dtype
        if max_pixels is None:
            max_pixels = int(os.environ.get(
                "FORENSICS_PROBE_MAX_PIXELS", str(3584 * 28 * 28)
            ))
        if min_pixels is None:
            min_pixels = int(os.environ.get(
                "FORENSICS_PROBE_MIN_PIXELS", str(16 * 28 * 28)
            ))

        # Load with the policy's ZeRO-3 config hidden so the frozen weights are
        # not partitioned by the training engine.
        with _no_deepspeed_zero3():
            self.processor = AutoProcessor.from_pretrained(
                model_path, max_pixels=max_pixels, min_pixels=min_pixels
            )
            self.tokenizer = self.processor.tokenizer
            # probe_batch reads the logits at position -1, which is each row's
            # real last prompt token only under LEFT padding. probe_batch also
            # passes padding_side="left" per call. Pinning it here keeps the
            # invariant if a processor version ignores that call-time kwarg.
            self.tokenizer.padding_side = "left"
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype=dtype,
                attn_implementation=attn_implementation,
            ).to(self.device).eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

        # Token ids whose next-token logits are compared in probe_batch.
        # NOTE(review): only the first single-token variant of each answer is
        # scored (lowercase first). Probability the model puts on other casings
        # (e.g. "Yes") is therefore not counted. Confirm this is intended.
        self.yes_token_id = self._pick_token_id(("yes", " yes", "Yes", " Yes"))
        self.no_token_id = self._pick_token_id(("no", " no", "No", " No"))

    def _pick_token_id(self, variants: Tuple[str, ...]) -> int:
        """Return the id of the first variant that tokenises to exactly one token.

        If no variant is a single token, fall back to the first sub-token of
        ``variants[0]``. A ``RuntimeWarning`` is emitted in that case, because
        that id is only a proxy for the intended word.
        """
        import warnings

        for v in variants:
            ids = self.tokenizer.encode(v, add_special_tokens=False)
            if len(ids) == 1:
                return ids[0]
        fallback = self.tokenizer.encode(variants[0], add_special_tokens=False)
        warnings.warn(
            f"None of {variants!r} is a single token; falling back to the "
            f"first sub-token ({fallback[0]}) of {variants[0]!r}.",
            RuntimeWarning,
            stacklevel=2,
        )
        return fallback[0]

    def _build_chat_text(self, question: str) -> str:
        """Render a one-turn chat prompt: a video placeholder, then the question.

        The question is followed by a fixed yes/no answer instruction. The
        generation prompt is appended so the next token is the answer.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video"},
                    {
                        "type": "text",
                        "text": question + "\n\nAnswer with a single word: yes or no.",
                    },
                ],
            },
        ]
        return self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    @torch.no_grad()
    def probe_batch(
        self,
        video_clips: List[torch.Tensor],
        fps_list: List[float],
        questions: List[str],
    ) -> List[Tuple[float, float]]:
        """Run probes; return [(P(yes), P(no)), ...] per probe.

        `video_clips[i]` must be a (T, C, H, W) tensor with T >= 2 (Qwen2.5-VL
        temporal_patch_size=2 requires an even frame count). All elements are
        forwarded in a single batch, so the caller should chunk by GPU memory.

        The two probabilities are a softmax over just the yes/no token logits
        at the final prompt position, so each pair sums to 1.

        Raises:
            ValueError: if ``video_clips``, ``fps_list`` and ``questions`` do
                not have the same length (``fps_list`` is only checked when it
                is sized).
        """
        n_probes = len(video_clips)
        n_fps = len(fps_list) if hasattr(fps_list, "__len__") else n_probes
        if len(questions) != n_probes or n_fps != n_probes:
            raise ValueError(
                "probe_batch expects equally long inputs, got "
                f"{n_probes} clips, {n_fps} fps values and "
                f"{len(questions)} questions"
            )
        if not video_clips:
            return []
        prompts_text = [self._build_chat_text(q) for q in questions]

        with _no_deepspeed_zero3():
            inputs = self.processor(
                text=prompts_text,
                videos=video_clips,
                fps=fps_list,
                padding=True,
                return_tensors="pt",
                padding_side="left",
                add_special_tokens=False,
            )
        inputs = {
            k: (v.to(self.device) if hasattr(v, "to") else v)
            for k, v in inputs.items()
        }

        # NOTE(review): outputs.logits presumably holds (batch, seq_len, vocab)
        # logits although only position -1 is read. For long video prompts that
        # tensor can be large. Recent transformers releases accept
        # `logits_to_keep=1` on this forward; confirm against the installed
        # version before passing it.
        with _no_deepspeed_zero3():
            outputs = self.model(**inputs, use_cache=False)

        # Left padding => position -1 is each row's final prompt token.
        last_logits = outputs.logits[:, -1, :].float()
        pair_logits = torch.stack(
            (last_logits[:, self.yes_token_id], last_logits[:, self.no_token_id]),
            dim=-1,
        )
        # Softmax restricted to {yes, no}: renormalises the two answers.
        probs = torch.softmax(pair_logits, dim=-1).cpu().numpy()
        return [(float(y), float(n)) for y, n in probs]
|
|
|
|
def get_prober() -> BinaryProber:
    """Return the process-wide BinaryProber, building it on first use.

    An existing instance is returned without locking. Otherwise the check is
    repeated under ``_INSTANCE_LOCK``, so concurrent first callers load the
    model only once.

    Raises:
        RuntimeError: if no instance exists yet and FORENSICS_PROBE_MODEL is
            unset or empty.
    """
    global _INSTANCE
    if _INSTANCE is not None:
        return _INSTANCE
    with _INSTANCE_LOCK:
        if _INSTANCE is None:
            checkpoint = os.environ.get("FORENSICS_PROBE_MODEL")
            if not checkpoint:
                raise RuntimeError(
                    "FORENSICS_PROBE_MODEL is not set. Point it at a "
                    "frozen Qwen2.5-VL checkpoint."
                )
            _INSTANCE = BinaryProber(model_path=checkpoint)
    return _INSTANCE
|
|
|
|
def slice_video_by_time(
    video_input,
    fps: float,
    start_s: float,
    end_s: float,
    min_frames: int = 4,
) -> Optional[torch.Tensor]:
    """Return the frames between ``start_s`` and ``end_s`` as (T, C, H, W).

    The selected frame indices are ``round(start_s * fps)`` (inclusive) to
    ``round(end_s * fps)`` (exclusive), clamped to the video. Two adjustments
    follow, in order:

    1. A clip shorter than ``min_frames`` is extended forward to
       ``min_frames`` frames, and shifted back if it hits the end of the video.
    2. Qwen2.5-VL uses temporal_patch_size=2, so an odd frame count is made
       even by growing the end, else the start, by one frame. Only when the
       clip already spans the whole video is one trailing frame dropped
       instead, provided at least ``max(min_frames, 2)`` frames remain.

    Args:
        video_input: a 4-D tensor (or anything ``torch.as_tensor`` accepts),
            sliced along dim 0.
        fps: frames per second of ``video_input``.
        start_s: clip start time in seconds.
        end_s: clip end time in seconds.
        min_frames: minimum number of frames in the returned clip.

    Returns:
        A contiguous slice of ``video_input``, or ``None`` when the input is not
        4-D, the time range selects no frames, or the video is too short to
        satisfy the constraints above.
    """
    if not torch.is_tensor(video_input):
        video_input = torch.as_tensor(video_input)
    if video_input.ndim != 4:
        # Only (T, C, H, W) input is supported.
        return None

    num_frames = video_input.shape[0]
    start_f = max(0, int(round(start_s * fps)))
    end_f = min(num_frames, int(round(end_s * fps)))
    if end_f <= start_f:
        return None

    if end_f - start_f < min_frames:
        deficit = min_frames - (end_f - start_f)
        end_f = min(num_frames, end_f + deficit)
        start_f = max(0, end_f - min_frames)
        if end_f - start_f < min_frames:
            return None

    if (end_f - start_f) % 2 != 0:
        if end_f < num_frames:
            end_f += 1
        elif start_f > 0:
            start_f -= 1
        elif end_f - start_f - 1 >= max(min_frames, 2):
            # The clip already covers the whole odd-length video. Drop the last
            # frame rather than discard a clip that is long enough.
            end_f -= 1
        else:
            return None

    return video_input[start_f:end_f].contiguous()
|
|