| # vrthinker |
|
|
| A video reward model that compares two videos against a text prompt and outputs per-dimension preferences: |
| **TA** (Text Alignment), **MQ** (Motion Quality), **VQ** (Visual Quality), **OA** (Overall). |
| Each label is one of `1` (Video 1 wins), `2` (Video 2 wins), `0` (tie). |
|
|
| The model reasons step-by-step and may call a `select_frames` tool to request additional frames from the videos |
| before committing to an answer. |
|
|
| ## Install |
|
|
| ```bash |
| pip install torch transformers accelerate pillow opencv-python |
| ``` |
|
|
| ## Inference |
|
|
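Save the snippet below as `infer.py` inside the model directory (next to this README and the weights — the script loads the model from its own folder), then run: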
|
|
| ```bash |
| python infer.py --video1 path/to/v1.mp4 --video2 path/to/v2.mp4 \ |
| --prompt "A robot rides a unicorn across a rainbow bridge." |
| ``` |
|
|
| ```python |
| # infer.py |
| import argparse, json, re |
| from pathlib import Path |
| |
| import cv2 |
| import torch |
| from PIL import Image |
| from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration |
| |
# Directory the processor and model are loaded from: the folder containing this
# script, so infer.py has to live inside the model directory (next to the README).
MODEL_DIR = str(Path(__file__).resolve().parent)
# Size of the virtual per-video index space; requested indices in
# [0, FRAMES_PER_VIDEO) are rescaled onto each video's real frame count.
FRAMES_PER_VIDEO = 128
# Number of frames per video sent with the first user message.
INITIAL_PER_VIDEO = 4
# Maximum number of generate / tool-call rounds before giving up.
MAX_TURNS = 6
# Upper bound on how many indices a single select_frames call is served.
MAX_FRAMES_PER_CALL = 12
# Every extracted frame is resized to IMAGE_SIDE x IMAGE_SIDE pixels.
IMAGE_SIDE = 448
| |
# System message sent verbatim on every run. It fixes the response format
# (<Snapshot>, <Think>, <Answer>) and the <tool_call> JSON shape that the
# parsers in this script look for.
# NOTE: the counts written in the text (four frames per video, 128 indices,
# 12 indices per call) are literal prose, not interpolated from the module
# constants; if a constant is changed, this text has to be edited by hand.
SYSTEM_PROMPT = """Task Description:
Your task is to compare two videos generated based on the same text prompt by analyzing their frames in detail and provide an overall judgment along with a judgment for each evaluation dimension.

The provided frames are downsampled from these videos:
- Video 1: First four input frames.
- Video 2: Next four input frames.

Evaluation Dimensions:
1. Text Alignment (TA): How faithfully each video reflects the text prompt.
2. Visual Quality (VQ): Aesthetics, artifacts, blurriness, distortion, color, resolution, flickering.
3. Motion Quality (MQ): Smoothness, jitter, unnatural motion, temporal consistency.
4. Overall Assessment (OA): Holistic judgment across the above.

Frames and Analysis Rules:
- 8 sampled frames are provided initially (4 per video), evenly downsampled from 128 frames per video. The first 4 are Video 1, the next 4 are Video 2.
- Each video has 128 frames (indices 0-127). To inspect more frames, call select_frames with the indices you need; the tool retrieves the same indices from both videos symmetrically.
- Tool returns are paired: for [i, j, k] you get (v1[i], v2[i], v1[j], v2[j], v1[k], v2[k]). Use this pairing to compare the same moment across both videos.
- Each tool call accepts at most 12 indices.

Format Requirement:
1. <Snapshot></Snapshot> — summarize useful visual details after receiving frames.
2. <Think></Think> — reasoning.
3. <Answer></Answer> — final judgment.

Label semantics: 1 = Video 1 better, 2 = Video 2 better, 0 = tie.

Examples:
<Answer>TA=1, VQ=1, MQ=0, OA=1</Answer>

Tool call format:
When you want to inspect more frames, emit a tool call inside <tool_call></tool_call> tags:
<tool_call>{"name": "select_frames", "arguments": {"frame_indices": [10, 30, 60, 90]}}</tool_call>
"""
| |
| |
def _read_frame(cap, position: int, max_backtrack: int = 16):
    """Decode the frame at ``position``, stepping back a few frames on failure.

    Returns the raw BGR array, or ``None`` when none of the tried positions
    could be decoded.
    """
    # The frame count a container reports can exceed what is actually
    # decodable, so a read near the end may fail; try slightly earlier frames
    # instead of dropping the slot.
    lowest = max(position - max_backtrack, 0)
    for candidate in range(position, lowest - 1, -1):
        cap.set(cv2.CAP_PROP_POS_FRAMES, candidate)
        ok, frame = cap.read()
        if ok:
            return frame
    return None


def extract_frames(video_path: str, indices: list[int]) -> list[Image.Image]:
    """Return one RGB PIL frame per requested index, in request order.

    Each index lives in the virtual range ``[0, FRAMES_PER_VIDEO)`` and is
    rescaled proportionally onto the video's reported frame count; values
    outside that range are clamped to the first / last frame. Every frame is
    resized to ``IMAGE_SIDE`` x ``IMAGE_SIDE``.

    Args:
        video_path: Path of the video file to read.
        indices: Virtual frame indices to fetch.

    Returns:
        A list with exactly ``len(indices)`` images.

    Raises:
        ValueError: If the video cannot be opened, reports no frames, or a
            requested frame (and its near predecessors) cannot be decoded.
            Failing loudly keeps callers from receiving fewer images than
            they asked for, which would silently shift which images belong
            to which video.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total <= 0:
            raise ValueError(f"Video reports no frames: {video_path}")

        out: list[Image.Image] = []
        for idx in indices:
            # map idx in [0, FRAMES_PER_VIDEO) -> real frame in [0, total)
            real = min(max(int(idx / FRAMES_PER_VIDEO * total), 0), total - 1)
            frame = _read_frame(cap, real)
            if frame is None:
                raise ValueError(
                    f"Cannot decode frame {real} (index {idx}) of {video_path}"
                )
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            out.append(Image.fromarray(rgb).resize((IMAGE_SIDE, IMAGE_SIDE)))
        return out
    finally:
        # Release the capture even when decoding raises.
        cap.release()
| |
| |
def initial_frames(v1: str, v2: str) -> list[Image.Image]:
    """Return the opening frame set: Video 1's samples followed by Video 2's.

    The virtual index range is cut into ``INITIAL_PER_VIDEO`` equal segments
    and the midpoint of each segment is sampled from both videos.
    """
    midpoints: list[int] = []
    for slot in range(INITIAL_PER_VIDEO):
        midpoints.append(int(FRAMES_PER_VIDEO * (slot + 0.5) / INITIAL_PER_VIDEO))

    first_video = extract_frames(v1, midpoints)
    second_video = extract_frames(v2, midpoints)
    return [*first_video, *second_video]
| |
| |
def tool_frames(v1: str, v2: str, indices: list[int]) -> list[Image.Image]:
    """Return frames for a select_frames call, interleaved as v1[i], v2[i], ...

    ``indices`` comes from model-generated JSON, so it is sanitised first:
    a non-list payload yields no frames, non-integer entries are ignored
    (integral floats such as ``10.0`` are accepted), values are clamped to
    ``[0, FRAMES_PER_VIDEO - 1]``, and at most ``MAX_FRAMES_PER_CALL`` indices
    are served.

    Args:
        v1: Path of the first video.
        v2: Path of the second video.
        indices: Requested virtual frame indices.

    Returns:
        A flat list of images in (v1[i], v2[i]) pairs. An index is included
        only when both videos produced a frame, so the pairing never shifts.
    """
    if not isinstance(indices, (list, tuple)):
        return []

    wanted: list[int] = []
    for item in indices:
        if isinstance(item, bool):  # bool is an int subclass; not a frame index
            continue
        if isinstance(item, float) and item.is_integer():
            item = int(item)
        if isinstance(item, int):
            wanted.append(min(max(item, 0), FRAMES_PER_VIDEO - 1))
    wanted = wanted[:MAX_FRAMES_PER_CALL]

    out: list[Image.Image] = []
    for i in wanted:
        # Fetch one index at a time so a frame missing from one video can be
        # detected and the whole pair dropped together.
        left = extract_frames(v1, [i])
        right = extract_frames(v2, [i])
        if left and right:
            out += [left[0], right[0]]
    return out
| |
| |
| def parse_tool_call(text: str) -> dict | None: |
| m = re.search(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", text, re.DOTALL) |
| if not m: |
| return None |
| try: |
| obj = json.loads(m.group(1)) |
| return obj.get("arguments", {}) |
| except json.JSONDecodeError: |
| return None |
| |
| |
| def parse_answer(text: str) -> dict | None: |
| m = re.search(r"<Answer>(.*?)</Answer>", text, re.DOTALL | re.IGNORECASE) |
| if not m: |
| return None |
| body = m.group(1) |
| return {d: int(re.search(rf"{d}\s*=\s*(\d)", body).group(1)) |
| for d in ("TA", "MQ", "VQ", "OA") |
| if re.search(rf"{d}\s*=\s*(\d)", body)} |
| |
| |
def _clean_frame_indices(raw) -> list[int]:
    """Turn a model-supplied ``frame_indices`` value into a list of usable ints.

    Non-list payloads give ``[]``; non-integer entries are dropped (integral
    floats are accepted); values are clamped to ``[0, FRAMES_PER_VIDEO - 1]``;
    the result is cut to ``MAX_FRAMES_PER_CALL`` entries.
    """
    if not isinstance(raw, (list, tuple)):
        return []
    cleaned: list[int] = []
    for item in raw:
        if isinstance(item, bool):  # bool is an int subclass; not a frame index
            continue
        if isinstance(item, float) and item.is_integer():
            item = int(item)
        if isinstance(item, int):
            cleaned.append(min(max(item, 0), FRAMES_PER_VIDEO - 1))
    return cleaned[:MAX_FRAMES_PER_CALL]


@torch.inference_mode()
def run(video1: str, video2: str, prompt: str) -> dict:
    """Compare two videos against ``prompt`` and return per-dimension labels.

    Loads the processor and model from ``MODEL_DIR``, sends the initial frames
    of both videos, then loops for at most ``MAX_TURNS`` rounds: generate a
    reply, return as soon as it contains an answer, otherwise serve the
    requested ``select_frames`` call and continue.

    Args:
        video1: Path of the first video.
        video2: Path of the second video.
        prompt: Text prompt both videos were generated from.

    Returns:
        A dict with the keys ``TA``, ``MQ``, ``VQ``, ``OA``. Values are the
        parsed integer labels; a dimension the model did not report is
        ``None``. All four are ``None`` when the model neither answered nor
        issued a usable tool call, or when the turn budget ran out.
    """
    no_verdict = {"TA": None, "MQ": None, "VQ": None, "OA": None}

    processor = AutoProcessor.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_DIR, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
    ).eval()

    images = initial_frames(video1, video2)
    user_text = (
        f"Compare the two videos generated from the following prompt and evaluate them "
        f"across Text Alignment (TA), Motion Quality (MQ), Visual Quality (VQ), and "
        f"Overall Assessment (OA).\n\nPrompt: {prompt}\n\n"
        f"The first 4 images are uniformly sampled from Video 1, and the next 4 are from "
        f"Video 2. Each video has 128 frames (indices 0-127). "
        f"Use the select_frames tool to request additional frames if needed."
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        # One placeholder dict per image (no shared object between entries).
        {"role": "user", "content": [{"type": "image"} for _ in images]
                                    + [{"type": "text", "text": user_text}]},
    ]

    for _ in range(MAX_TURNS):
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The full image history is re-sent every turn, matching the placeholders in `messages`.
        inputs = processor(text=[text], images=images, return_tensors="pt", padding=True).to(model.device)
        # Greedy decoding; no temperature is passed because it only affects sampling.
        output_ids = model.generate(**inputs, max_new_tokens=2048, do_sample=False)
        # Strip the prompt tokens so only the newly generated text is decoded.
        reply = processor.batch_decode(
            output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        )[0]

        messages.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})

        answer = parse_answer(reply)
        if answer:
            # Pad dimensions the model omitted so the result always has four keys.
            return {**no_verdict, **answer}

        call = parse_tool_call(reply)
        if not isinstance(call, dict) or "frame_indices" not in call:
            return dict(no_verdict)

        # Sanitise before fetching so the tool response below describes the
        # indices that were actually served, not the raw request.
        served = _clean_frame_indices(call["frame_indices"])
        new_imgs = tool_frames(video1, video2, served)
        images += new_imgs
        messages.append({
            "role": "user",
            "content": [{"type": "image"} for _ in new_imgs]
                       + [{"type": "text",
                           "text": f"<tool_response>Retrieved {len(served)} "
                                   f"frame pairs ({served}) symmetrically from both "
                                   f"videos.</tool_response>"}],
        })

    return dict(no_verdict)
| |
| |
# Command-line entry point: three required options, result printed as JSON.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    for flag in ("--video1", "--video2", "--prompt"):
        parser.add_argument(flag, required=True)
    cli = parser.parse_args()

    verdict = run(cli.video1, cli.video2, cli.prompt)
    print(json.dumps(verdict, indent=2))
| ``` |
|
|
| ## Output |
|
|
| ``` |
| { |
| "TA": 1, |
| "MQ": 0, |
| "VQ": 2, |
| "OA": 1 |
| } |
| ``` |
|
|
| `1` = Video 1 wins on that dimension, `2` = Video 2 wins, `0` = tie. |
|
|
| ## Hardware |
|
|
| Requires ~16 GB GPU memory in bf16. Tested on a single A100/H100. |
|
|