"""Frozen-reference binary probing infrastructure (R1 / R3 reward). Loads a frozen Qwen2.5-VL once per process and exposes a batched API: prober = get_prober() results = prober.probe_batch(video_clips, fps_list, questions) # results[i] = (P(yes | clip_i, question_i), P(no | clip_i, question_i)) Plus a `slice_video_by_time` helper for cutting Qwen-cached video tensors by time range. The frozen reference avoids reward hacking: the policy cannot shift the prober's answers during training. Environment: FORENSICS_PROBE_MODEL path to frozen Qwen2.5-VL checkpoint (default: same Qwen2.5-VL-7B-Instruct path as policy) FORENSICS_PROBE_DEVICE override device (default: cuda:{LOCAL_RANK}) FORENSICS_PROBE_MAX_PIXELS processor.max_pixels (default: 3584 * 28*28) FORENSICS_PROBE_MIN_PIXELS processor.min_pixels (default: 16 * 28*28) """ from __future__ import annotations import os import threading from typing import List, Optional, Tuple import contextlib import traceback import numpy as np import torch from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration @contextlib.contextmanager def _no_deepspeed_zero3(): """Temporarily disable HF transformers' DeepSpeed ZeRO-3 integration so a standalone frozen model can be loaded WITHOUT having its weights partitioned by the policy's deepspeed engine. Needed when the prober is instantiated inside a training process where the policy is under ZeRO-3.""" import importlib try: ds_mod = importlib.import_module("transformers.integrations.deepspeed") except Exception: yield return saved_ref = getattr(ds_mod, "_hf_deepspeed_config_weak_ref", None) try: ds_mod._hf_deepspeed_config_weak_ref = None yield finally: ds_mod._hf_deepspeed_config_weak_ref = saved_ref _INSTANCE: Optional["BinaryProber"] = None _INSTANCE_LOCK = threading.Lock() def _local_rank_device() -> str: rank = int(os.environ.get("LOCAL_RANK", "0")) return f"cuda:{rank}" class BinaryProber: """Frozen Qwen2.5-VL prober for binary yes/no questions over short clips.""" def __init__( self, model_path: str, device: Optional[str] = None, dtype: torch.dtype = torch.bfloat16, max_pixels: Optional[int] = None, min_pixels: Optional[int] = None, ): self.device = device or os.environ.get( "FORENSICS_PROBE_DEVICE", _local_rank_device() ) self.dtype = dtype if max_pixels is None: max_pixels = int(os.environ.get( "FORENSICS_PROBE_MAX_PIXELS", str(3584 * 28 * 28) )) if min_pixels is None: min_pixels = int(os.environ.get( "FORENSICS_PROBE_MIN_PIXELS", str(16 * 28 * 28) )) with _no_deepspeed_zero3(): self.processor = AutoProcessor.from_pretrained( model_path, max_pixels=max_pixels, min_pixels=min_pixels ) self.tokenizer = self.processor.tokenizer self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( model_path, torch_dtype=dtype, attn_implementation="flash_attention_2", ).to(self.device).eval() for p in self.model.parameters(): p.requires_grad_(False) # The chat template ends in `assistant\n` so the first generated token # is the literal word — typically tokenized with a leading space. self.yes_token_id = self._pick_token_id(("yes", " yes", "Yes", " Yes")) self.no_token_id = self._pick_token_id(("no", " no", "No", " No")) def _pick_token_id(self, variants: Tuple[str, ...]) -> int: """Pick the first variant that tokenises to exactly one token.""" for v in variants: ids = self.tokenizer.encode(v, add_special_tokens=False) if len(ids) == 1: return ids[0] # Fallback: first token of the no-space lowercase variant. return self.tokenizer.encode(variants[0], add_special_tokens=False)[0] def _build_chat_text(self, question: str) -> str: messages = [ { "role": "user", "content": [ {"type": "video"}, { "type": "text", "text": question + "\n\nAnswer with a single word: yes or no.", }, ], }, ] return self.processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) @torch.no_grad() def probe_batch( self, video_clips: List[torch.Tensor], fps_list: List[float], questions: List[str], ) -> List[Tuple[float, float]]: """Run probes; return [(P(yes), P(no)), ...] per probe. `video_clips[i]` must be a (T, C, H, W) tensor with T >= 2 (Qwen2.5-VL temporal_patch_size=2 requires an even frame count). All elements are forwarded in a single batch — caller should chunk by GPU memory. """ if not video_clips: return [] prompts_text = [self._build_chat_text(q) for q in questions] with _no_deepspeed_zero3(): inputs = self.processor( text=prompts_text, videos=video_clips, fps=fps_list, padding=True, return_tensors="pt", padding_side="left", add_special_tokens=False, ) inputs = { k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items() } with _no_deepspeed_zero3(): outputs = self.model(**inputs, use_cache=False) # `logits[:, -1, :]` corresponds to the next predicted token, i.e. the # first word of the assistant answer in this chat template. last_logits = outputs.logits[:, -1, :].float() yes_l = last_logits[:, self.yes_token_id] no_l = last_logits[:, self.no_token_id] # Renormalise over the 2-class subspace. m = torch.maximum(yes_l, no_l) z = torch.log(torch.exp(yes_l - m) + torch.exp(no_l - m)) + m p_yes = torch.exp(yes_l - z).cpu().numpy() p_no = torch.exp(no_l - z).cpu().numpy() return [(float(y), float(n)) for y, n in zip(p_yes, p_no)] def get_prober() -> BinaryProber: """Process-wide singleton. Lazy-loaded on first call.""" global _INSTANCE if _INSTANCE is None: with _INSTANCE_LOCK: if _INSTANCE is None: model_path = os.environ.get("FORENSICS_PROBE_MODEL") if not model_path: raise RuntimeError( "FORENSICS_PROBE_MODEL is not set. Point it at a " "frozen Qwen2.5-VL checkpoint." ) _INSTANCE = BinaryProber(model_path=model_path) return _INSTANCE def slice_video_by_time( video_input, fps: float, start_s: float, end_s: float, min_frames: int = 4, ) -> Optional[torch.Tensor]: """Return frames in [start_s, end_s] as (T, C, H, W). None if too short. Handles Qwen2.5-VL temporal_patch_size=2 constraint by enforcing even frame counts (snaps boundary outward when needed). """ if not torch.is_tensor(video_input): video_input = torch.as_tensor(video_input) if video_input.ndim != 4: # Defensive: some pipelines return list-of-frames; try to stack. return None T = video_input.shape[0] start_f = max(0, int(round(start_s * fps))) end_f = min(T, int(round(end_s * fps))) if end_f <= start_f: return None if end_f - start_f < min_frames: deficit = min_frames - (end_f - start_f) end_f = min(T, end_f + deficit) start_f = max(0, end_f - min_frames) if end_f - start_f < min_frames: return None # Even-frame constraint for temporal patchification. if (end_f - start_f) % 2 != 0: if end_f < T: end_f += 1 elif start_f > 0: start_f -= 1 else: return None return video_input[start_f:end_f].contiguous()