# Source: forensics-grpo/code/scripts/_test_prober_oneoff.py
# (Hugging Face repo; uploaded by sdzt in commit 33569f9 "Add source code", 1.75 kB)
"""Tiny standalone test to reproduce 'weight must be 2-D' from probe_batch.
Loads ONE cached video_inputs, calls slice_video_by_time, then runs
prober.probe_batch with a single probe. Prints the full traceback.
"""
import json
import os
import sys
import traceback
import torch
sys.path.insert(0, "/mnt/local-fast/zhangt/forensics_grpo")
from src.open_r1.binary_prober import BinaryProber, slice_video_by_time
CACHE_DIR = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0/test/fcvg/0DVVD+2.80=14.30=charades@test_delete@0DVVD@287@fcvg"
MODEL_PATH = "/mnt/local-fast/zhangt/Qwen2.5-VL-7B-Instruct"
# Time window (in seconds) of the clip passed to the prober.
CLIP_START_S = 0.0
CLIP_END_S = 4.0
PROBE_PROMPT = "Watch the following short video clip. Is it internally coherent?"


def unwrap_video_inputs(video_inputs):
    """Return the single video object from a cached ``video_inputs`` value.

    The cache may hold either the object itself or a list of them; for a
    list, the first entry is used because this script probes one video.

    Raises:
        ValueError: if ``video_inputs`` is an empty list.
    """
    if not isinstance(video_inputs, list):
        return video_inputs
    if not video_inputs:
        raise ValueError("cached video_inputs is an empty list")
    return video_inputs[0]


def resolve_fps(video_kwargs):
    """Return a single float fps taken from the cached ``video_kwargs`` dict.

    ``fps`` may be stored as a scalar or as a list; for a list, the first
    entry is used.

    Raises:
        KeyError: if ``fps`` is missing or None.
        ValueError: if ``fps`` is an empty list.
    """
    fps = video_kwargs.get("fps")
    if fps is None:
        raise KeyError("video_kwargs has no 'fps' entry")
    if isinstance(fps, list):
        if not fps:
            raise ValueError("video_kwargs['fps'] is an empty list")
        fps = fps[0]
    return float(fps)


def main(cache_dir=CACHE_DIR, model_path=MODEL_PATH):
    """Load one cached sample, slice a clip, and run ``probe_batch`` on it.

    Prints diagnostics along the way and the full traceback if
    ``probe_batch`` raises.

    Returns:
        0 if ``probe_batch`` succeeded, 1 if it raised. The caller passes
        this to ``sys.exit`` so the shell can tell whether the error
        reproduced.
    """
    # NOTE: weights_only=False unpickles arbitrary objects -- only point this
    # at a cache directory you produced yourself.
    vi = torch.load(
        os.path.join(cache_dir, "video_inputs.pt"),
        map_location="cpu",
        weights_only=False,
    )
    with open(os.path.join(cache_dir, "video_kwargs.json"), encoding="utf-8") as f:
        vk = json.load(f)

    print(f"video_inputs type: {type(vi)}")
    if isinstance(vi, list):
        print(f" outer list len: {len(vi)}")
    vi = unwrap_video_inputs(vi)
    print(f" tensor shape: {getattr(vi, 'shape', None)} dtype: {getattr(vi, 'dtype', None)}")
    print(f"video_kwargs: {vk}")

    fps = resolve_fps(vk)
    print(f"effective fps: {fps}")

    clip = slice_video_by_time(vi, fps, CLIP_START_S, CLIP_END_S)
    print(f"sliced clip shape: {getattr(clip, 'shape', None)} dtype: {getattr(clip, 'dtype', None)}")

    print("\n=== Loading prober ===")
    os.environ.setdefault("LOCAL_RANK", "0")
    prober = BinaryProber(model_path=model_path)
    print(f"yes_token_id={prober.yes_token_id} no_token_id={prober.no_token_id}")

    print("\n=== Calling probe_batch ===")
    try:
        out = prober.probe_batch([clip], [fps], [PROBE_PROMPT])
    except Exception:
        # Surfacing this traceback is the purpose of the script; report the
        # failure through the exit status instead of swallowing it.
        traceback.print_exc()
        return 1
    print("OUTPUT:", out)
    return 0


if __name__ == "__main__":
    sys.exit(main())