# vrthinker

A video reward model that compares two videos against a text prompt and outputs per-dimension preferences:
**TA** (Text Alignment), **MQ** (Motion Quality), **VQ** (Visual Quality), **OA** (Overall).
Each label is one of `1` (Video 1 wins), `2` (Video 2 wins), `0` (tie).

The model reasons step-by-step and may call a `select_frames` tool to request additional frames from the videos
before committing to an answer.
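
As a rough illustration of that tool round-trip, the sketch below parses a `select_frames` call out of a model reply and lists the order in which frames come back. It follows the call format and pairing rule stated in the system prompt of `infer.py`; the helper names `requested_indices` and `paired_order` are hypothetical and not part of the script.

```python
import json
import re

MAX_FRAMES_PER_CALL = 12  # per-call cap stated in the system prompt


def requested_indices(reply: str) -> list[int]:
    """Return the frame indices from the first <tool_call> in a reply, or [] if none parse."""
    m = re.search(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", reply, re.DOTALL)
    if not m:
        return []
    try:
        args = json.loads(m.group(1)).get("arguments", {})
        indices = [int(i) for i in args.get("frame_indices", [])]
    except (json.JSONDecodeError, AttributeError, TypeError, ValueError):
        return []
    return indices[:MAX_FRAMES_PER_CALL]


def paired_order(indices: list[int]) -> list[tuple[str, int]]:
    """Frames come back interleaved: v1[i], v2[i], v1[j], v2[j], ..."""
    return [(video, i) for i in indices for video in ("v1", "v2")]


reply = (
    "<Think>Need mid-clip frames.</Think>"
    '<tool_call>{"name": "select_frames", "arguments": {"frame_indices": [10, 30]}}</tool_call>'
)
print(paired_order(requested_indices(reply)))
# [('v1', 10), ('v2', 10), ('v1', 30), ('v2', 30)]
```

A reply without a well-formed tool call yields an empty list, which a caller can treat as "no more frames requested".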

## Install

```bash
pip install torch transformers accelerate pillow opencv-python
```

## Inference

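Save the snippet below as `infer.py` inside the model directory (it loads the weights from its own folder and needs Python 3.10+ for the `dict | None` annotations), then run: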

```bash
python infer.py --video1 path/to/v1.mp4 --video2 path/to/v2.mp4 \
                --prompt "A robot rides a unicorn across a rainbow bridge."
```

```python
# infer.py
import argparse, json, re
from pathlib import Path

import cv2
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

MODEL_DIR = str(Path(__file__).resolve().parent)        # directory containing infer.py; keep the script next to the model files
FRAMES_PER_VIDEO = 128
INITIAL_PER_VIDEO = 4
MAX_TURNS = 6
MAX_FRAMES_PER_CALL = 12
IMAGE_SIDE = 448

SYSTEM_PROMPT = """Task Description:
Your task is to compare two videos generated based on the same text prompt by analyzing their frames in detail and provide an overall judgment along with a judgment for each evaluation dimension.

The provided frames are downsampled from these videos:
- Video 1: First four input frames.
- Video 2: Next four input frames.

Evaluation Dimensions:
1. Text Alignment (TA): How faithfully each video reflects the text prompt.
2. Visual Quality (VQ): Aesthetics, artifacts, blurriness, distortion, color, resolution, flickering.
3. Motion Quality (MQ): Smoothness, jitter, unnatural motion, temporal consistency.
4. Overall Assessment (OA): Holistic judgment across the above.

Frames and Analysis Rules:
- 8 sampled frames are provided initially (4 per video), evenly downsampled from 128 frames per video. The first 4 are Video 1, the next 4 are Video 2.
- Each video has 128 frames (indices 0-127). To inspect more frames, call select_frames with the indices you need; the tool retrieves the same indices from both videos symmetrically.
- Tool returns are paired: for [i, j, k] you get (v1[i], v2[i], v1[j], v2[j], v1[k], v2[k]). Use this pairing to compare the same moment across both videos.
- Each tool call accepts at most 12 indices.

Format Requirement:
1. <Snapshot></Snapshot> — summarize useful visual details after receiving frames.
2. <Think></Think> — reasoning.
3. <Answer></Answer> — final judgment.

Label semantics: 1 = Video 1 better, 2 = Video 2 better, 0 = tie.

Examples:
<Answer>TA=1, VQ=1, MQ=0, OA=1</Answer>

Tool call format:
When you want to inspect more frames, emit a tool call inside <tool_call></tool_call> tags:
<tool_call>{"name": "select_frames", "arguments": {"frame_indices": [10, 30, 60, 90]}}</tool_call>
"""


def extract_frames(video_path: str, indices: list[int]) -> list[Image.Image]:
    """Return PIL frames at the given indices, evenly mapped over the video's actual length."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    out: list[Image.Image] = []
    for idx in indices:
        # map idx in [0, FRAMES_PER_VIDEO) -> real frame in [0, total), clamping out-of-range requests
        real = max(0, min(int(idx / FRAMES_PER_VIDEO * total), total - 1))
        cap.set(cv2.CAP_PROP_POS_FRAMES, real)
        ok, frame = cap.read()
        if not ok:
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(frame).resize((IMAGE_SIDE, IMAGE_SIDE))
        out.append(img)
    cap.release()
    return out


def initial_frames(v1: str, v2: str) -> list[Image.Image]:
    idxs = [int(FRAMES_PER_VIDEO * (i + 0.5) / INITIAL_PER_VIDEO) for i in range(INITIAL_PER_VIDEO)]
    return extract_frames(v1, idxs) + extract_frames(v2, idxs)


def tool_frames(v1: str, v2: str, indices: list[int]) -> list[Image.Image]:
    indices = indices[:MAX_FRAMES_PER_CALL]
    out: list[Image.Image] = []
    for i in indices:
        out += extract_frames(v1, [i])
        out += extract_frames(v2, [i])
    return out


def parse_tool_call(text: str) -> dict | None:
    m = re.search(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", text, re.DOTALL)
    if not m:
        return None
    try:
        obj = json.loads(m.group(1))
        return obj.get("arguments", {})
    except json.JSONDecodeError:
        return None


def parse_answer(text: str) -> dict | None:
    m = re.search(r"<Answer>(.*?)</Answer>", text, re.DOTALL | re.IGNORECASE)
    if not m:
        return None
    body = m.group(1)
    return {d: int(re.search(rf"{d}\s*=\s*(\d)", body).group(1))
            for d in ("TA", "MQ", "VQ", "OA")
            if re.search(rf"{d}\s*=\s*(\d)", body)}


@torch.inference_mode()
def run(video1: str, video2: str, prompt: str) -> dict:
    processor = AutoProcessor.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_DIR, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
    ).eval()

    images = initial_frames(video1, video2)
    user_text = (
        f"Compare the two videos generated from the following prompt and evaluate them "
        f"across Text Alignment (TA), Motion Quality (MQ), Visual Quality (VQ), and "
        f"Overall Assessment (OA).\n\nPrompt: {prompt}\n\n"
        f"The first 4 images are uniformly sampled from Video 1, and the next 4 are from "
        f"Video 2. Each video has 128 frames (indices 0-127). "
        f"Use the select_frames tool to request additional frames if needed."
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image"}] * len(images)
                                    + [{"type": "text", "text": user_text}]},
    ]

    for turn in range(MAX_TURNS):
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[text], images=images, return_tensors="pt", padding=True).to(model.device)
        # Greedy decoding; temperature has no effect when sampling is off, so it is not passed.
        output_ids = model.generate(**inputs, max_new_tokens=2048, do_sample=False)
        reply = processor.batch_decode(
            output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        )[0]

        messages.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})

        answer = parse_answer(reply)
        if answer:
            return answer

        call = parse_tool_call(reply)
        if not call or "frame_indices" not in call:
            # Neither an answer nor a usable tool call: stop and return empty labels.
            return {"TA": None, "MQ": None, "VQ": None, "OA": None}

        # Apply the per-call cap up front so the tool response reports what was actually retrieved.
        indices = call["frame_indices"][:MAX_FRAMES_PER_CALL]
        new_imgs = tool_frames(video1, video2, indices)
        images += new_imgs
        messages.append({
            "role": "user",
            "content": [{"type": "image"}] * len(new_imgs)
                       + [{"type": "text",
                           "text": f"<tool_response>Retrieved {len(indices)} "
                                   f"frame pairs ({indices}) symmetrically from both "
                                   f"videos.</tool_response>"}],
        })

    return {"TA": None, "MQ": None, "VQ": None, "OA": None}


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--video1", required=True)
    p.add_argument("--video2", required=True)
    p.add_argument("--prompt", required=True)
    args = p.parse_args()
    print(json.dumps(run(args.video1, args.video2, args.prompt), indent=2))
```
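
To make the frame indexing concrete, here is a small standalone sketch of the arithmetic used in `infer.py`: the four initial indices sit at the midpoints of four equal segments of the 128-slot index space, and each index is then scaled to the clip's real frame count. The 240-frame and 64-frame clips are made-up examples, and `initial_indices` and `to_real_frame` are hypothetical helpers that mirror the expressions in `initial_frames` and `extract_frames`.

```python
FRAMES_PER_VIDEO = 128   # logical index space exposed to the model
INITIAL_PER_VIDEO = 4    # frames shown per video before any tool call


def initial_indices() -> list[int]:
    """Midpoints of INITIAL_PER_VIDEO equal segments of the logical index space."""
    return [int(FRAMES_PER_VIDEO * (i + 0.5) / INITIAL_PER_VIDEO)
            for i in range(INITIAL_PER_VIDEO)]


def to_real_frame(idx: int, total: int) -> int:
    """Scale a logical index in [0, 128) to a real frame number in [0, total)."""
    return min(int(idx / FRAMES_PER_VIDEO * total), total - 1)


idxs = initial_indices()
print(idxs)                                   # [16, 48, 80, 112]
print([to_real_frame(i, 240) for i in idxs])  # [30, 90, 150, 210]
print(to_real_frame(127, 64))                 # 63
```

For clips with fewer than 128 real frames, neighbouring logical indices can map to the same real frame (with 64 frames, indices 0 and 1 both resolve to frame 0), so adjacent requests may return duplicates.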

## Output

```json
{
  "TA": 1,
  "MQ": 0,
  "VQ": 2,
  "OA": 1
}
```

`1` = Video 1 wins on that dimension, `2` = Video 2 wins, `0` = tie.
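
For downstream use, the label dictionary can be turned into readable text. The sketch below is one possible way to do that, not part of the model's tooling: `describe` is a hypothetical helper, and it reports a missing or `None` label as "no verdict" because `infer.py` returns `None` for every dimension when no `<Answer>` block could be parsed.

```python
DIMENSIONS = {
    "TA": "Text Alignment",
    "MQ": "Motion Quality",
    "VQ": "Visual Quality",
    "OA": "Overall",
}
VERDICTS = {1: "Video 1 wins", 2: "Video 2 wins", 0: "tie"}


def describe(result: dict) -> list[str]:
    """Render each dimension as 'Name: verdict'; missing, None, or unknown labels become 'no verdict'."""
    return [f"{name}: {VERDICTS.get(result.get(key), 'no verdict')}"
            for key, name in DIMENSIONS.items()]


for line in describe({"TA": 1, "MQ": 0, "VQ": 2, "OA": 1}):
    print(line)
# Text Alignment: Video 1 wins
# Motion Quality: tie
# Visual Quality: Video 2 wins
# Overall: Video 1 wins
```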

## Hardware

Requires ~16 GB GPU memory in bf16. Tested on a single A100/H100.