File size: 7,866 Bytes
# vrthinker
A video reward model that compares two videos against a text prompt and outputs per-dimension preferences:
**TA** (Text Alignment), **MQ** (Motion Quality), **VQ** (Visual Quality), **OA** (Overall).
Each label is one of `1` (Video 1 wins), `2` (Video 2 wins), `0` (tie).
The model reasons step-by-step and may call a `select_frames` tool to request additional frames from the videos
before committing to an answer.
## Install
```bash
pip install torch transformers accelerate pillow opencv-python
```
## Inference
Save the snippet below as `infer.py` inside this model directory (the script loads the weights from the folder it lives in), then run:
```bash
python infer.py --video1 path/to/v1.mp4 --video2 path/to/v2.mp4 \
--prompt "A robot rides a unicorn across a rainbow bridge."
```
```python
# infer.py
import argparse, json, re
from pathlib import Path
import cv2
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# Directory holding this script; run() passes it to from_pretrained() for both
# the processor and the model, so the script must sit next to the weights.
MODEL_DIR = str(Path(__file__).resolve().parent)
# Length of the virtual frame timeline. Frame indices exchanged with the model
# live in [0, FRAMES_PER_VIDEO); extract_frames() maps them proportionally onto
# each video's real frame count.
FRAMES_PER_VIDEO = 128
# Number of frames per video sent with the first user message.
INITIAL_PER_VIDEO = 4
# Upper bound on generate() rounds in run() before giving up without an answer.
MAX_TURNS = 6
# Upper bound on indices honoured per select_frames call (see tool_frames()).
MAX_FRAMES_PER_CALL = 12
# Every frame is resized to IMAGE_SIDE x IMAGE_SIDE pixels before being sent.
IMAGE_SIDE = 448
# NOTE: SYSTEM_PROMPT and the user text in run() spell out 128, 4 and 12 as
# literal text, so changing these constants does not update the prompts.
# System message used as the first chat turn in run(). It describes the task,
# the four evaluation dimensions, the select_frames tool protocol and the
# <Snapshot>/<Think>/<Answer> output format that parse_answer() and
# parse_tool_call() rely on. The frame counts (4 per video, 128 indices,
# 12 per call) are written out as text rather than derived from the constants.
# NOTE(review): presumably this must match the prompt the model was trained
# with verbatim - confirm before editing the wording.
SYSTEM_PROMPT = """Task Description:
Your task is to compare two videos generated based on the same text prompt by analyzing their frames in detail and provide an overall judgment along with a judgment for each evaluation dimension.
The provided frames are downsampled from these videos:
- Video 1: First four input frames.
- Video 2: Next four input frames.
Evaluation Dimensions:
1. Text Alignment (TA): How faithfully each video reflects the text prompt.
2. Visual Quality (VQ): Aesthetics, artifacts, blurriness, distortion, color, resolution, flickering.
3. Motion Quality (MQ): Smoothness, jitter, unnatural motion, temporal consistency.
4. Overall Assessment (OA): Holistic judgment across the above.
Frames and Analysis Rules:
- 8 sampled frames are provided initially (4 per video), evenly downsampled from 128 frames per video. The first 4 are Video 1, the next 4 are Video 2.
- Each video has 128 frames (indices 0-127). To inspect more frames, call select_frames with the indices you need; the tool retrieves the same indices from both videos symmetrically.
- Tool returns are paired: for [i, j, k] you get (v1[i], v2[i], v1[j], v2[j], v1[k], v2[k]). Use this pairing to compare the same moment across both videos.
- Each tool call accepts at most 12 indices.
Format Requirement:
1. <Snapshot></Snapshot> — summarize useful visual details after receiving frames.
2. <Think></Think> — reasoning.
3. <Answer></Answer> — final judgment.
Label semantics: 1 = Video 1 better, 2 = Video 2 better, 0 = tie.
Examples:
<Answer>TA=1, VQ=1, MQ=0, OA=1</Answer>
Tool call format:
When you want to inspect more frames, emit a tool call inside <tool_call></tool_call> tags:
<tool_call>{"name": "select_frames", "arguments": {"frame_indices": [10, 30, 60, 90]}}</tool_call>
"""
def extract_frames(video_path: str, indices: list[int]) -> list[Image.Image]:
    """Return resized RGB frames for the requested virtual frame indices.

    Each index is interpreted on a virtual timeline of ``FRAMES_PER_VIDEO``
    positions and mapped proportionally onto the video's reported frame count.

    Args:
        video_path: Path of the video file to read.
        indices: Virtual frame indices, nominally in ``[0, FRAMES_PER_VIDEO)``.
            Values outside that range are clamped to the first/last frame.

    Returns:
        One ``IMAGE_SIDE`` x ``IMAGE_SIDE`` PIL image per index that could be
        decoded, in request order. Frames that fail to decode are skipped, so
        the result may be shorter than ``indices`` (empty for an unreadable
        video).
    """
    cap = cv2.VideoCapture(video_path)
    out: list[Image.Image] = []
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for idx in indices:
            # Map idx in [0, FRAMES_PER_VIDEO) -> real frame in [0, total).
            real = int(idx / FRAMES_PER_VIDEO * total)
            # Clamp both ends: the upper bound guards idx >= FRAMES_PER_VIDEO,
            # the lower bound guards negative idx and a reported total of 0
            # (where total - 1 would otherwise yield a seek position of -1).
            real = max(0, min(real, total - 1))
            cap.set(cv2.CAP_PROP_POS_FRAMES, real)
            ok, frame = cap.read()
            if not ok:
                continue
            # OpenCV decodes to BGR; PIL expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            out.append(Image.fromarray(frame).resize((IMAGE_SIDE, IMAGE_SIDE)))
    finally:
        # Release the decoder even if an index is malformed and raises above.
        cap.release()
    return out
def initial_frames(v1: str, v2: str) -> list[Image.Image]:
    """Build the opening image list: sampled frames of ``v1`` followed by ``v2``.

    The virtual timeline is split into ``INITIAL_PER_VIDEO`` equal bins and the
    centre index of each bin is requested from both videos.
    """
    centres: list[int] = []
    for slot in range(INITIAL_PER_VIDEO):
        centres.append(int(FRAMES_PER_VIDEO * (slot + 0.5) / INITIAL_PER_VIDEO))
    first_video = extract_frames(v1, centres)
    second_video = extract_frames(v2, centres)
    return first_video + second_video
def tool_frames(v1: str, v2: str, indices: list[int]) -> list[Image.Image]:
    """Fetch the frames for a ``select_frames`` call, interleaved per index.

    Args:
        v1: Path of the first video.
        v2: Path of the second video.
        indices: Requested virtual frame indices; only the first
            ``MAX_FRAMES_PER_CALL`` are honoured.

    Returns:
        A flat list ``[v1[i], v2[i], v1[j], v2[j], ...]``. An index is dropped
        entirely when either video fails to yield its frame, so the list always
        consists of complete (v1, v2) pairs as the system prompt promises.
    """
    out: list[Image.Image] = []
    for i in indices[:MAX_FRAMES_PER_CALL]:
        first = extract_frames(v1, [i])
        second = extract_frames(v2, [i])
        # A lone frame would shift every later pair by one position.
        if first and second:
            out += first + second
    return out
def parse_tool_call(text: str) -> dict | None:
m = re.search(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", text, re.DOTALL)
if not m:
return None
try:
obj = json.loads(m.group(1))
return obj.get("arguments", {})
except json.JSONDecodeError:
return None
def parse_answer(text: str) -> dict | None:
m = re.search(r"<Answer>(.*?)</Answer>", text, re.DOTALL | re.IGNORECASE)
if not m:
return None
body = m.group(1)
return {d: int(re.search(rf"{d}\s*=\s*(\d)", body).group(1))
for d in ("TA", "MQ", "VQ", "OA")
if re.search(rf"{d}\s*=\s*(\d)", body)}
@torch.inference_mode()
def run(video1: str, video2: str, prompt: str) -> dict:
    """Compare two videos against *prompt* and return per-dimension labels.

    Loads the processor and model from ``MODEL_DIR``, sends the initial frames
    of both videos, then alternates between generation and ``select_frames``
    tool responses for at most ``MAX_TURNS`` rounds.

    Args:
        video1: Path of the first video.
        video2: Path of the second video.
        prompt: Text prompt both videos were generated from.

    Returns:
        The dict produced by ``parse_answer`` for the first reply containing a
        non-empty answer. If the model neither answers nor issues a usable
        tool call, or the turn budget runs out, a dict with all four keys
        (``TA``, ``MQ``, ``VQ``, ``OA``) set to ``None``.
    """
    processor = AutoProcessor.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_DIR, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
    ).eval()

    no_answer = {"TA": None, "MQ": None, "VQ": None, "OA": None}
    images = initial_frames(video1, video2)
    user_text = (
        f"Compare the two videos generated from the following prompt and evaluate them "
        f"across Text Alignment (TA), Motion Quality (MQ), Visual Quality (VQ), and "
        f"Overall Assessment (OA).\n\nPrompt: {prompt}\n\n"
        f"The first 4 images are uniformly sampled from Video 1, and the next 4 are from "
        f"Video 2. Each video has 128 frames (indices 0-127). "
        f"Use the select_frames tool to request additional frames if needed."
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image"}] * len(images)
                                    + [{"type": "text", "text": user_text}]},
    ]

    for _ in range(MAX_TURNS):
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[text], images=images, return_tensors="pt", padding=True).to(model.device)
        # Greedy decoding: do_sample=False is sufficient. temperature is a
        # sampling-only setting and is therefore not passed here.
        output_ids = model.generate(**inputs, max_new_tokens=2048, do_sample=False)
        # Strip the prompt tokens so only the newly generated reply is decoded.
        reply = processor.batch_decode(
            output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        )[0]
        messages.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})

        answer = parse_answer(reply)
        if answer:
            return answer

        call = parse_tool_call(reply)
        if not isinstance(call, dict) or "frame_indices" not in call:
            # Neither an answer nor a usable tool call: nothing left to do.
            return dict(no_answer)

        requested = call["frame_indices"]
        if isinstance(requested, int) and not isinstance(requested, bool):
            requested = [requested]
        elif not isinstance(requested, list):
            requested = []
        # Keep only in-range integer indices and apply the per-call cap here,
        # so the tool response below reports exactly what was fetched rather
        # than what the model asked for.
        indices = [
            i for i in requested
            if isinstance(i, int) and not isinstance(i, bool) and 0 <= i < FRAMES_PER_VIDEO
        ][:MAX_FRAMES_PER_CALL]

        new_imgs = tool_frames(video1, video2, indices)
        images += new_imgs
        messages.append({
            "role": "user",
            "content": [{"type": "image"}] * len(new_imgs)
                       + [{"type": "text",
                           "text": f"<tool_response>Retrieved {len(indices)} "
                                   f"frame pairs ({indices}) symmetrically from both "
                                   f"videos.</tool_response>"}],
        })
    return dict(no_answer)
if __name__ == "__main__":
    # Command-line entry point: three required string options, JSON on stdout.
    parser = argparse.ArgumentParser()
    for flag in ("--video1", "--video2", "--prompt"):
        parser.add_argument(flag, required=True)
    cli = parser.parse_args()
    result = run(cli.video1, cli.video2, cli.prompt)
    print(json.dumps(result, indent=2))
```
## Output
```json
{
"TA": 1,
"MQ": 0,
"VQ": 2,
"OA": 1
}
```
`1` = Video 1 wins on that dimension, `2` = Video 2 wins, `0` = tie.
## Hardware
Requires ~16 GB GPU memory in bf16. Tested on a single A100/H100.
|