VEFX-Reward commited on
Commit
f666f1f
·
verified ·
1 Parent(s): 982d249

Add VEFX-Bench reference code

Browse files
.gitignore ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.egg-info/
5
+ dist/
6
+ build/
7
+ *.egg
8
+
9
+ # Environment
10
+ .env
11
+ *.env
12
+
13
+ # IDE
14
+ .vscode/
15
+ .idea/
16
+
17
+ # Model weights
18
+ *.safetensors
19
+ *.pth
20
+ *.bin
21
+ *.ckpt
22
+ *.onnx
23
+
24
+ # Data
25
+ *.avi
26
+ *.mov
27
+
28
+ # Keep sample videos for examples
29
+ !examples/sample_videos/*.mp4
30
+
31
+ # OS
32
+ .DS_Store
33
+ Thumbs.db
34
+
35
+ # Outputs
36
+ merged_model/
37
+ results/
LICENSE ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to the Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by the Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding any notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
examples/batch_scoring.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch scoring: Evaluate multiple video edits from a CSV file.
3
+
4
+ Expected CSV format:
5
+ original_video,edited_video,instruction
6
+ path/to/orig1.mp4,path/to/edit1.mp4,"make it snowy"
7
+ path/to/orig2.mp4,path/to/edit2.mp4,"add a red hat"
8
+
9
+ Usage:
10
+ python examples/batch_scoring.py \
11
+ --csv edits.csv \
12
+ --output results.csv
13
+ """
14
+
15
+ import argparse
16
+ import csv
17
+
18
+ import torch
19
+ from vefx_reward import VEFXReward
20
+
21
+
22
+ def main():
23
+ parser = argparse.ArgumentParser(description="Batch score video edits")
24
+ parser.add_argument("--csv", required=True, help="Input CSV with columns: original_video, edited_video, instruction")
25
+ parser.add_argument("--output", default="results.csv", help="Output CSV path")
26
+ parser.add_argument("--model", default="VEFX-Reward/VEFX-Reward-4B")
27
+ parser.add_argument("--device", default="cuda")
28
+ args = parser.parse_args()
29
+
30
+ model = VEFXReward(args.model, device=args.device)
31
+
32
+ with open(args.csv) as f:
33
+ rows = list(csv.DictReader(f))
34
+ print(f"Loaded {len(rows)} samples from {args.csv}")
35
+
36
+ results = []
37
+ for i, row in enumerate(rows):
38
+ scores = model.score(row["original_video"], row["edited_video"], row["instruction"])
39
+ results.append({**row, **scores})
40
+ print(f"[{i+1}/{len(rows)}] IF={scores['IF']:.2f} RQ={scores['RQ']:.2f} EE={scores['EE']:.2f} Overall={scores['Overall']:.2f}")
41
+
42
+ fieldnames = list(rows[0].keys()) + ["IF", "RQ", "EE", "Overall"]
43
+ with open(args.output, "w", newline="") as f:
44
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
45
+ writer.writeheader()
46
+ writer.writerows(results)
47
+ print(f"\nResults saved to {args.output}")
48
+
49
+
50
+ if __name__ == "__main__":
51
+ main()
examples/multi_gpu_scoring.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-GPU parallel scoring using subprocess workers.
3
+
4
+ Splits a CSV of video edits across multiple GPUs for faster inference.
5
+
6
+ Usage:
7
+ python examples/multi_gpu_scoring.py \
8
+ --csv edits.csv \
9
+ --output results.csv \
10
+ --num_gpus 4
11
+ """
12
+
13
+ import argparse
14
+ import csv
15
+ import json
16
+ import os
17
+ import subprocess
18
+ import sys
19
+ import tempfile
20
+
21
+
22
+ def worker_main(args):
23
+ """Single-GPU worker: load model, score shard, write results."""
24
+ import torch
25
+ from vefx_reward import VEFXReward
26
+
27
+ with open(args.shard_file) as f:
28
+ shard = json.load(f)
29
+
30
+ model = VEFXReward(args.model, device="cuda:0")
31
+
32
+ results = []
33
+ for i, item in enumerate(shard):
34
+ try:
35
+ scores = model.score(item["original_video"], item["edited_video"], item["instruction"])
36
+ results.append({**item, **scores})
37
+ print(f"[GPU {args.gpu_id}] [{i+1}/{len(shard)}] "
38
+ f"IF={scores['IF']:.2f} RQ={scores['RQ']:.2f} EE={scores['EE']:.2f}", flush=True)
39
+ except Exception as e:
40
+ print(f"[GPU {args.gpu_id}] [{i+1}/{len(shard)}] ERROR: {e}", flush=True)
41
+ results.append({**item, "IF": None, "RQ": None, "EE": None, "Overall": None, "error": str(e)})
42
+
43
+ with open(args.output_file, "w") as f:
44
+ json.dump(results, f)
45
+ print(f"[GPU {args.gpu_id}] Done — {len(results)} results", flush=True)
46
+
47
+
48
+ def main():
49
+ parser = argparse.ArgumentParser(description="Multi-GPU video edit scoring")
50
+ parser.add_argument("--csv", required=True, help="Input CSV")
51
+ parser.add_argument("--output", default="results.csv", help="Output CSV")
52
+ parser.add_argument("--model", default="VEFX-Reward/VEFX-Reward-4B")
53
+ parser.add_argument("--num_gpus", type=int, default=4)
54
+ # Internal worker args
55
+ parser.add_argument("--_worker", action="store_true", help=argparse.SUPPRESS)
56
+ parser.add_argument("--gpu_id", type=int, default=0, help=argparse.SUPPRESS)
57
+ parser.add_argument("--shard_file", type=str, default="", help=argparse.SUPPRESS)
58
+ parser.add_argument("--output_file", type=str, default="", help=argparse.SUPPRESS)
59
+ args = parser.parse_args()
60
+
61
+ if args._worker:
62
+ worker_main(args)
63
+ return
64
+
65
+ # --- Launcher mode ---
66
+ with open(args.csv) as f:
67
+ rows = list(csv.DictReader(f))
68
+ print(f"Loaded {len(rows)} samples, distributing across {args.num_gpus} GPUs")
69
+
70
+ items = [dict(row) for row in rows]
71
+ shards = [[] for _ in range(args.num_gpus)]
72
+ for i, item in enumerate(items):
73
+ shards[i % args.num_gpus].append(item)
74
+
75
+ tmpdir = tempfile.mkdtemp(prefix="vefx_multi_")
76
+ script = os.path.abspath(__file__)
77
+ procs = []
78
+ for gid in range(args.num_gpus):
79
+ if not shards[gid]:
80
+ continue
81
+ sf = os.path.join(tmpdir, f"shard_{gid}.json")
82
+ of = os.path.join(tmpdir, f"result_{gid}.json")
83
+ with open(sf, "w") as f:
84
+ json.dump(shards[gid], f)
85
+ env = os.environ.copy()
86
+ env["CUDA_VISIBLE_DEVICES"] = str(gid)
87
+ env["TOKENIZERS_PARALLELISM"] = "false"
88
+ p = subprocess.Popen(
89
+ [sys.executable, script,
90
+ "--_worker", "--gpu_id", str(gid),
91
+ "--shard_file", sf, "--output_file", of,
92
+ "--model", args.model],
93
+ env=env, stdout=sys.stdout, stderr=sys.stderr,
94
+ )
95
+ procs.append((p, of))
96
+
97
+ for p, _ in procs:
98
+ p.wait()
99
+
100
+ # Merge results
101
+ all_results = []
102
+ for _, of in procs:
103
+ if os.path.exists(of):
104
+ with open(of) as f:
105
+ all_results.extend(json.load(f))
106
+
107
+ fieldnames = list(rows[0].keys()) + ["IF", "RQ", "EE", "Overall"]
108
+ with open(args.output, "w", newline="") as f:
109
+ writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
110
+ writer.writeheader()
111
+ writer.writerows(all_results)
112
+ print(f"\nAll done — {len(all_results)} results saved to {args.output}")
113
+
114
+ import shutil
115
+ shutil.rmtree(tmpdir, ignore_errors=True)
116
+
117
+
118
+ if __name__ == "__main__":
119
+ main()
examples/quick_start.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick start: Score a single video edit with VEFX-Reward.
3
+
4
+ Usage:
5
+ python examples/quick_start.py \
6
+ --original path/to/original.mp4 \
7
+ --edited path/to/edited.mp4 \
8
+ --instruction "add a hat to the person"
9
+ """
10
+
11
+ import argparse
12
+ import torch
13
+ from vefx_reward import VEFXReward
14
+
15
+
16
+ def main():
17
+ parser = argparse.ArgumentParser(description="Score a video edit with VEFX-Reward")
18
+ parser.add_argument("--original", required=True, help="Path to original video")
19
+ parser.add_argument("--edited", required=True, help="Path to edited video")
20
+ parser.add_argument("--instruction", required=True, help="Editing instruction")
21
+ parser.add_argument("--model", default="VEFX-Reward/VEFX-Reward-4B", help="Model path or HF ID")
22
+ parser.add_argument("--device", default="cuda", help="Device (cuda / cpu)")
23
+ args = parser.parse_args()
24
+
25
+ model = VEFXReward(args.model, device=args.device)
26
+ scores = model.score(args.original, args.edited, args.instruction)
27
+
28
+ print("\n" + "=" * 50)
29
+ print("VEFX-Reward Scores")
30
+ print("=" * 50)
31
+ print(f" Instructional Following (IF): {scores['IF']:.2f}")
32
+ print(f" Render Quality (RQ): {scores['RQ']:.2f}")
33
+ print(f" Edit Exclusivity (EE): {scores['EE']:.2f}")
34
+ print(f" Overall : {scores['Overall']:.2f}")
35
+ print("=" * 50)
36
+
37
+
38
+ if __name__ == "__main__":
39
+ main()
examples/sample_videos/edited.mp4 ADDED
Binary file (21.5 kB). View file
 
examples/sample_videos/original.mp4 ADDED
Binary file (21.4 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ torchvision>=0.16.0
3
+ transformers>=4.51.0
4
+ accelerate>=0.30.0
5
+ safetensors>=0.4.0
6
+ huggingface_hub>=0.20.0
7
+ Pillow>=10.0.0
8
+ numpy>=1.24.0
9
+ requests
10
+ packaging
11
+ decord>=0.6.0
scripts/prepare_and_upload.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prepare and upload VEFX-Reward model to HuggingFace Hub.
3
+
4
+ This script:
5
+ 1. Loads the Qwen3-VL-4B base model
6
+ 2. Manually merges LoRA weights into the base model
7
+ 3. Loads non-LoRA weights (rm_head, merger, special token embeddings)
8
+ 4. Saves and uploads the complete model to HuggingFace
9
+
10
+ Prerequisites:
11
+ pip install huggingface_hub safetensors
12
+ huggingface-cli login
13
+
14
+ Usage:
15
+ python scripts/prepare_and_upload.py \
16
+ --checkpoint_dir /path/to/training/logs/v4/ord_4B_lora_2stage_promptv2_res399k \
17
+ --checkpoint_step 1050 \
18
+ --hf_repo VEFX-Reward/VEFX-Reward-4B \
19
+ --output_dir ./merged_model
20
+ """
21
+
22
+ import argparse
23
+ import json
24
+ import os
25
+
26
+ import safetensors.torch as st
27
+ import torch
28
+ from transformers import AutoProcessor, AutoTokenizer
29
+
30
+ import sys
31
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
32
+ from vefx_reward.model import Qwen3VLRewardModelBT
33
+
34
+ SPECIAL_TOKENS = [
35
+ "<|VQ_reward|>", "<|MQ_reward|>", "<|TA_reward|>",
36
+ "<|IF_reward|>", "<|RQ_reward|>", "<|EE_reward|>",
37
+ ]
38
+
39
+
40
+ def main():
41
+ parser = argparse.ArgumentParser(description="Prepare and upload VEFX-Reward model")
42
+ parser.add_argument("--checkpoint_dir", required=True,
43
+ help="Training output directory containing model_config.json and checkpoint-*/")
44
+ parser.add_argument("--checkpoint_step", type=int, default=-1,
45
+ help="Checkpoint step to use (-1 = latest)")
46
+ parser.add_argument("--hf_repo", default="VEFX-Reward/VEFX-Reward-4B",
47
+ help="HuggingFace repo ID to upload to")
48
+ parser.add_argument("--output_dir", default="./merged_model",
49
+ help="Local directory to save merged model before upload")
50
+ parser.add_argument("--upload", action="store_true",
51
+ help="Actually upload to HuggingFace (otherwise just save locally)")
52
+ args = parser.parse_args()
53
+
54
+ # 1. Load training config
55
+ config_path = os.path.join(args.checkpoint_dir, "model_config.json")
56
+ with open(config_path) as f:
57
+ config_dict = json.load(f)
58
+
59
+ model_config = config_dict["model_config"]
60
+ data_config = config_dict["data_config"]
61
+
62
+ base_model_path = model_config["model_name_or_path"]
63
+ output_dim = model_config["output_dim"]
64
+ use_ordinal = model_config["use_ordinal"]
65
+ num_classes = model_config["num_classes"]
66
+ reward_token = model_config["reward_token"]
67
+
68
+ print(f"Base model: {base_model_path}")
69
+ print(f"Output dim: {output_dim}, Ordinal: {use_ordinal}, Num classes: {num_classes}")
70
+
71
+ # 2. Find checkpoint
72
+ import glob as globmod
73
+ ckpt_dirs = sorted(globmod.glob(os.path.join(args.checkpoint_dir, "checkpoint-*")),
74
+ key=lambda x: int(x.split("-")[-1]))
75
+ if args.checkpoint_step == -1:
76
+ ckpt_path = ckpt_dirs[-1]
77
+ else:
78
+ ckpt_path = os.path.join(args.checkpoint_dir, f"checkpoint-{args.checkpoint_step}")
79
+ print(f"Using checkpoint: {ckpt_path}")
80
+
81
+ # 3. Load processor from base model with checkpoint's tokenizer
82
+ processor = AutoProcessor.from_pretrained(base_model_path, padding_side="right")
83
+ ckpt_tokenizer_path = os.path.join(ckpt_path, "tokenizer")
84
+ if os.path.isdir(ckpt_tokenizer_path):
85
+ processor.tokenizer = AutoTokenizer.from_pretrained(ckpt_tokenizer_path)
86
+ else:
87
+ processor.tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
88
+ special_token_ids = processor.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
89
+ print(f"Tokenizer vocab size: {len(processor.tokenizer)}")
90
+
91
+ # 4. Load base model with reward head
92
+ print("Loading base model...")
93
+ model = Qwen3VLRewardModelBT.from_pretrained(
94
+ base_model_path,
95
+ torch_dtype=torch.bfloat16,
96
+ output_dim=output_dim,
97
+ reward_token=reward_token,
98
+ special_token_ids=special_token_ids,
99
+ use_ordinal=use_ordinal,
100
+ num_classes=num_classes,
101
+ use_cache=True,
102
+ device_map="cpu",
103
+ )
104
+ model.resize_token_embeddings(len(processor.tokenizer))
105
+ print(f"Model embeddings resized to {len(processor.tokenizer)}")
106
+
107
+ # 5. Manual LoRA merge (bypasses PEFT tie_word_embeddings issues)
108
+ print("Loading and merging adapter weights...")
109
+ adapter_weights = st.load_file(os.path.join(ckpt_path, "adapter_model.safetensors"), device="cpu")
110
+
111
+ with open(os.path.join(ckpt_path, "adapter_config.json")) as f:
112
+ lora_cfg = json.load(f)
113
+ scaling = lora_cfg["lora_alpha"] / lora_cfg["r"]
114
+
115
+ # Categorize adapter keys
116
+ base_layers, lora_As, lora_Bs, emb_As, emb_Bs = {}, {}, {}, {}, {}
117
+ for k, v in adapter_weights.items():
118
+ ck = k.replace("base_model.model.", "")
119
+ if ".base_layer.weight" in ck:
120
+ base_layers[ck.replace(".base_layer.weight", "")] = v
121
+ elif ".lora_A.weight" in ck:
122
+ lora_As[ck.replace(".lora_A.weight", "")] = v
123
+ elif ".lora_B.weight" in ck:
124
+ lora_Bs[ck.replace(".lora_B.weight", "")] = v
125
+ elif ".lora_embedding_A" in ck:
126
+ emb_As[ck.replace(".lora_embedding_A", "")] = v
127
+ elif ".lora_embedding_B" in ck:
128
+ emb_Bs[ck.replace(".lora_embedding_B", "")] = v
129
+
130
+ model_state = model.state_dict()
131
+
132
+ # Replace base layer weights (for resized lm_head / embed_tokens)
133
+ for mod, w in base_layers.items():
134
+ key = mod + ".weight"
135
+ if key in model_state:
136
+ model_state[key] = w.to(model_state[key].dtype)
137
+ print(f" Replaced base layer: {key}")
138
+
139
+ # Merge LoRA: W_merged = W + B @ A * scaling
140
+ merged_count = 0
141
+ for mod in lora_As:
142
+ if mod in lora_Bs:
143
+ A, B = lora_As[mod].float(), lora_Bs[mod].float()
144
+ delta = (B @ A) * scaling
145
+ key = mod + ".weight"
146
+ if key in model_state:
147
+ model_state[key] = (model_state[key].float() + delta).to(torch.bfloat16)
148
+ merged_count += 1
149
+ print(f" Merged {merged_count} LoRA modules")
150
+
151
+ # Merge embedding LoRA
152
+ for mod in emb_As:
153
+ if mod in emb_Bs:
154
+ A, B = emb_As[mod].float(), emb_Bs[mod].float()
155
+ delta = (B @ A).T * scaling
156
+ key = mod + ".weight"
157
+ if key in model_state:
158
+ model_state[key] = (model_state[key].float() + delta).to(torch.bfloat16)
159
+ print(f" Merged embedding LoRA: {key}")
160
+
161
+ # 6. Load non-LoRA weights (rm_head, merger, special embeddings)
162
+ non_lora_path = os.path.join(ckpt_path, "non_lora_state_dict.pth")
163
+ if os.path.exists(non_lora_path):
164
+ print("Loading non-LoRA weights...")
165
+ non_lora_weights = torch.load(non_lora_path, map_location="cpu")
166
+ for k, v in non_lora_weights.items():
167
+ ck = k.replace("base_model.model.", "")
168
+ if ck in model_state:
169
+ model_state[ck] = v.to(model_state[ck].dtype)
170
+ print(f" Loaded: {ck}")
171
+
172
+ model.load_state_dict(model_state)
173
+ print(f"All weights loaded. rm_head shape: {model.rm_head.weight.shape}")
174
+
175
+ # 7. Save merged model
176
+ os.makedirs(args.output_dir, exist_ok=True)
177
+ print(f"Saving merged model to {args.output_dir}...")
178
+ model.save_pretrained(args.output_dir, safe_serialization=True)
179
+ processor.save_pretrained(args.output_dir)
180
+
181
+ # Save VEFX-specific config
182
+ vefx_config = {
183
+ "output_dim": output_dim,
184
+ "use_ordinal": use_ordinal,
185
+ "num_classes": num_classes,
186
+ "reward_token": reward_token,
187
+ "fps": data_config.get("fps", 4.0),
188
+ "max_frame_pixels": data_config.get("max_frame_pixels", 399360),
189
+ "eval_dim": data_config.get("eval_dim", ["IF", "RQ", "EE"]),
190
+ "prompt_template_type": data_config.get("prompt_template_type", "editreward_v2_special"),
191
+ }
192
+ with open(os.path.join(args.output_dir, "vefx_config.json"), "w") as f:
193
+ json.dump(vefx_config, f, indent=2)
194
+ print("Saved vefx_config.json")
195
+
196
+ # 8. Upload to HuggingFace
197
+ if args.upload:
198
+ from huggingface_hub import HfApi
199
+ api = HfApi()
200
+ print(f"Uploading to {args.hf_repo}...")
201
+ api.upload_folder(
202
+ folder_path=args.output_dir,
203
+ repo_id=args.hf_repo,
204
+ repo_type="model",
205
+ )
206
+ print(f"Upload complete: https://huggingface.co/{args.hf_repo}")
207
+ else:
208
+ print(f"\nModel saved to {args.output_dir}")
209
+ print(f"To upload, run again with --upload flag, or manually:")
210
+ print(f" huggingface-cli upload {args.hf_repo} {args.output_dir} .")
211
+
212
+
213
+ if __name__ == "__main__":
214
+ main()
setup.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name="vefx-reward",
5
+ version="0.1.0",
6
+ description="VEFX-Reward: A reward model for video editing quality assessment",
7
+ long_description=open("README.md").read(),
8
+ long_description_content_type="text/markdown",
9
+ url="",
10
+ packages=find_packages(),
11
+ python_requires=">=3.10",
12
+ install_requires=[
13
+ "torch>=2.1.0",
14
+ "torchvision>=0.16.0",
15
+ "transformers>=4.51.0",
16
+ "accelerate>=0.30.0",
17
+ "safetensors>=0.4.0",
18
+ "huggingface_hub>=0.20.0",
19
+ "Pillow>=10.0.0",
20
+ "numpy>=1.24.0",
21
+ "requests",
22
+ "packaging",
23
+ "decord>=0.6.0",
24
+ ],
25
+ classifiers=[
26
+ "Programming Language :: Python :: 3",
27
+ "License :: OSI Approved :: Apache Software License",
28
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
29
+ ],
30
+ )
vefx_reward/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VEFX-Reward: A reward model for video editing quality assessment.
3
+
4
+ Evaluates video edits on three dimensions (1–4 scale):
5
+ - IF (Instructional Following)
6
+ - RQ (Render Quality)
7
+ - EE (Edit Exclusivity)
8
+ """
9
+
10
+ __version__ = "0.1.0"
11
+
12
+ from .inference import VEFXReward
13
+
14
+ __all__ = ["VEFXReward"]
vefx_reward/inference.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VEFX-Reward: Video editing quality assessment inference API.
3
+
4
+ Usage:
5
+ from vefx_reward import VEFXReward
6
+
7
+ model = VEFXReward("VEFX-Reward/VEFX-Reward-4B", device="cuda")
8
+ scores = model.score("original.mp4", "edited.mp4", "add a hat to the person")
9
+ # {'IF': 3.21, 'RQ': 2.85, 'EE': 3.54, 'Overall': 9.60}
10
+ """
11
+
12
+ import json
13
+ import os
14
+ from collections.abc import Mapping
15
+ from typing import Optional
16
+
17
+ import numpy as np
18
+ import torch
19
+ from transformers import AutoProcessor
20
+
21
+ from .model import Qwen3VLRewardModelBT, ordinal_predict
22
+ from .prompt_template import build_prompt
23
+ from .vision_process import process_vision_info
24
+
25
+ # Default model hyperparameters (matching the released VEFX-Reward-4B)
26
+ DEFAULT_FPS = 4.0
27
+ DEFAULT_MAX_FRAME_PIXELS = 399360
28
+ DEFAULT_NUM_CLASSES = 4
29
+ DEFAULT_OUTPUT_DIM = 3
30
+ DIMS = ["IF", "RQ", "EE"]
31
+
32
+ SPECIAL_TOKENS = [
33
+ "<|VQ_reward|>", "<|MQ_reward|>", "<|TA_reward|>",
34
+ "<|IF_reward|>", "<|RQ_reward|>", "<|EE_reward|>",
35
+ ]
36
+
37
+
38
+ class VEFXReward:
39
+ """VEFX-Reward model for video editing quality assessment.
40
+
41
+ Scores video edits on three dimensions (1–4 scale):
42
+ - **IF** (Instructional Following): How well the edit follows the instruction.
43
+ - **RQ** (Render Quality): Visual and temporal quality of the edited video.
44
+ - **EE** (Edit Exclusivity): Whether only the intended region was modified.
45
+
46
+ Args:
47
+ model_path: HuggingFace model ID or local path
48
+ (e.g., ``"VEFX-Reward/VEFX-Reward-4B"``).
49
+ device: Device string (default ``"cuda"``).
50
+ dtype: Torch dtype (default ``torch.bfloat16``).
51
+ fps: Frames per second for video sampling (default 4.0).
52
+ max_frame_pixels: Maximum pixels per frame (default 399360).
53
+
54
+ Example::
55
+
56
+ model = VEFXReward("VEFX-Reward/VEFX-Reward-4B")
57
+ scores = model.score("original.mp4", "edited.mp4", "make it snowy")
58
+ print(scores)
59
+ # {'IF': 3.2, 'RQ': 2.9, 'EE': 3.5, 'Overall': 9.6}
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ model_path: str = "VEFX-Reward/VEFX-Reward-4B",
65
+ device: str = "cuda",
66
+ dtype: torch.dtype = torch.bfloat16,
67
+ fps: float = DEFAULT_FPS,
68
+ max_frame_pixels: int = DEFAULT_MAX_FRAME_PIXELS,
69
+ ):
70
+ self.device = device
71
+ self.dtype = dtype
72
+ self.fps = fps
73
+ self.max_frame_pixels = max_frame_pixels
74
+
75
+ # Load config
76
+ vefx_config_path = os.path.join(model_path, "vefx_config.json") if os.path.isdir(model_path) else None
77
+ if vefx_config_path and os.path.exists(vefx_config_path):
78
+ with open(vefx_config_path) as f:
79
+ vefx_config = json.load(f)
80
+ else:
81
+ # Try to download from HF hub
82
+ try:
83
+ from huggingface_hub import hf_hub_download
84
+ vefx_config_path = hf_hub_download(model_path, "vefx_config.json")
85
+ with open(vefx_config_path) as f:
86
+ vefx_config = json.load(f)
87
+ except Exception:
88
+ vefx_config = {}
89
+
90
+ self.num_classes = vefx_config.get("num_classes", DEFAULT_NUM_CLASSES)
91
+ self.output_dim = vefx_config.get("output_dim", DEFAULT_OUTPUT_DIM)
92
+ self.use_ordinal = vefx_config.get("use_ordinal", True)
93
+ reward_token = vefx_config.get("reward_token", "special")
94
+
95
+ # Load processor and add special tokens
96
+ self.processor = AutoProcessor.from_pretrained(model_path, padding_side="right")
97
+ existing_tokens = set(self.processor.tokenizer.get_vocab().keys())
98
+ tokens_to_add = [t for t in SPECIAL_TOKENS if t not in existing_tokens]
99
+ if tokens_to_add:
100
+ self.processor.tokenizer.add_special_tokens({"additional_special_tokens": tokens_to_add})
101
+ special_token_ids = self.processor.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
102
+
103
+ # Load model
104
+ self.model = Qwen3VLRewardModelBT.from_pretrained(
105
+ model_path,
106
+ torch_dtype=dtype,
107
+ output_dim=self.output_dim,
108
+ reward_token=reward_token,
109
+ special_token_ids=special_token_ids,
110
+ use_ordinal=self.use_ordinal,
111
+ num_classes=self.num_classes,
112
+ use_cache=True,
113
+ )
114
+ self.model.resize_token_embeddings(len(self.processor.tokenizer))
115
+
116
+ self.model.eval().to(self.device)
117
+ print(f"VEFX-Reward loaded on {self.device} ({dtype})")
118
+
119
+ def _prepare_input(self, data):
120
+ if isinstance(data, Mapping):
121
+ return type(data)({k: self._prepare_input(v) for k, v in data.items()})
122
+ elif isinstance(data, (tuple, list)):
123
+ return type(data)(self._prepare_input(v) for v in data)
124
+ elif isinstance(data, torch.Tensor):
125
+ return data.to(device=self.device)
126
+ return data
127
+
128
+ def _build_batch(self, original_video: str, edited_video: str, instruction: str):
129
+ """Build a single-sample batch from video paths and instruction."""
130
+ content = [
131
+ {
132
+ "type": "video",
133
+ "video": f"file://{os.path.abspath(original_video)}",
134
+ "max_pixels": self.max_frame_pixels,
135
+ "fps": self.fps,
136
+ "sample_type": "uniform",
137
+ },
138
+ {
139
+ "type": "video",
140
+ "video": f"file://{os.path.abspath(edited_video)}",
141
+ "max_pixels": self.max_frame_pixels,
142
+ "fps": self.fps,
143
+ "sample_type": "uniform",
144
+ },
145
+ {"type": "text", "text": build_prompt(instruction)},
146
+ ]
147
+ messages = [[{"role": "user", "content": content}]]
148
+ image_inputs, video_inputs, video_metadata_list = process_vision_info(messages)
149
+ video_inputs = [v.float() / 255.0 for v in video_inputs]
150
+
151
+ texts = self.processor.apply_chat_template(
152
+ messages, tokenize=False, add_generation_prompt=True
153
+ )
154
+ processor_kwargs = dict(
155
+ text=texts,
156
+ images=image_inputs,
157
+ videos=video_inputs,
158
+ padding=True,
159
+ return_tensors="pt",
160
+ videos_kwargs={"do_rescale": False, "do_sample_frames": False},
161
+ )
162
+ if video_metadata_list:
163
+ processor_kwargs["videos_kwargs"]["video_metadata"] = video_metadata_list
164
+ processor_kwargs["videos_kwargs"]["return_metadata"] = True
165
+
166
+ batch = self.processor(**processor_kwargs)
167
+ return self._prepare_input(batch)
168
+
169
+ def _logits_to_scores(self, logits: torch.Tensor) -> dict:
170
+ """Convert raw ordinal logits to IF/RQ/EE scores."""
171
+ logits_np = logits.float().cpu().numpy()
172
+ if self.use_ordinal:
173
+ num_dims = self.output_dim
174
+ num_thresholds = self.num_classes - 1
175
+ logits_reshaped = logits_np.reshape(1, num_dims, num_thresholds)
176
+ hard, soft = ordinal_predict(logits_reshaped, self.num_classes)
177
+ scores = {DIMS[j]: round(float(soft[0, j]), 3) for j in range(num_dims)}
178
+ else:
179
+ scores = {DIMS[j]: round(float(logits_np[0, j]), 3) for j in range(self.output_dim)}
180
+ scores["Overall"] = round(sum(scores[d] for d in DIMS), 3)
181
+ return scores
182
+
183
+ @torch.no_grad()
184
+ def score(
185
+ self,
186
+ original_video: str,
187
+ edited_video: str,
188
+ instruction: str,
189
+ ) -> dict:
190
+ """Score a single video edit.
191
+
192
+ Args:
193
+ original_video: Path to the original (source) video.
194
+ edited_video: Path to the edited video.
195
+ instruction: The editing instruction text.
196
+
197
+ Returns:
198
+ Dictionary with keys ``'IF'``, ``'RQ'``, ``'EE'``, ``'Overall'``.
199
+ Each dimension is scored on a continuous 1–4 scale.
200
+ """
201
+ batch = self._build_batch(original_video, edited_video, instruction)
202
+ logits = self.model(**batch, return_dict=True)["logits"]
203
+ return self._logits_to_scores(logits)
204
+
205
+ @torch.no_grad()
206
+ def score_batch(
207
+ self,
208
+ original_videos: list[str],
209
+ edited_videos: list[str],
210
+ instructions: list[str],
211
+ ) -> list[dict]:
212
+ """Score multiple video edits (processed sequentially to avoid OOM).
213
+
214
+ Args:
215
+ original_videos: List of paths to original videos.
216
+ edited_videos: List of paths to edited videos.
217
+ instructions: List of editing instruction texts.
218
+
219
+ Returns:
220
+ List of score dictionaries, one per sample.
221
+ """
222
+ assert len(original_videos) == len(edited_videos) == len(instructions)
223
+ results = []
224
+ for orig, edit, inst in zip(original_videos, edited_videos, instructions):
225
+ results.append(self.score(orig, edit, inst))
226
+ return results
vefx_reward/model.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VEFX-Reward: Qwen3-VL based reward model for video editing quality assessment.
3
+
4
+ Extends Qwen3VLForConditionalGeneration with an rm_head for ordinal regression,
5
+ scoring video edits on Instructional Following (IF), Render Quality (RQ),
6
+ and Edit Exclusivity (EE) on a 1–4 scale.
7
+ """
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import List, Optional
13
+ from transformers import Qwen3VLForConditionalGeneration
14
+
15
+
16
+ class Qwen3VLRewardModelBT(Qwen3VLForConditionalGeneration):
17
+ """Qwen3-VL with a reward head for ordinal video edit quality scoring."""
18
+
19
+ def __init__(self, config, output_dim=3, reward_token="special",
20
+ special_token_ids=None, use_ordinal=True, num_classes=4, **kwargs):
21
+ if 'use_cache' in kwargs:
22
+ config.use_cache = kwargs.pop('use_cache')
23
+ super().__init__(config, **kwargs)
24
+ self.output_dim = output_dim
25
+ self.rm_head = nn.Linear(config.text_config.hidden_size, output_dim, bias=False)
26
+ nn.init.normal_(self.rm_head.weight, mean=0.0, std=1.0 / config.text_config.hidden_size)
27
+ self.reward_token = reward_token
28
+ self.use_ordinal = use_ordinal
29
+ self.num_classes = num_classes
30
+ self.special_token_ids = special_token_ids
31
+ if self.special_token_ids is not None:
32
+ self.reward_token = "special"
33
+
34
+ def forward(
35
+ self,
36
+ input_ids: torch.LongTensor = None,
37
+ attention_mask: Optional[torch.Tensor] = None,
38
+ position_ids: Optional[torch.LongTensor] = None,
39
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
40
+ inputs_embeds: Optional[torch.FloatTensor] = None,
41
+ labels: Optional[torch.LongTensor] = None,
42
+ use_cache: Optional[bool] = None,
43
+ output_attentions: Optional[bool] = None,
44
+ output_hidden_states: Optional[bool] = None,
45
+ return_dict: Optional[bool] = None,
46
+ pixel_values: Optional[torch.Tensor] = None,
47
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
48
+ image_grid_thw: Optional[torch.LongTensor] = None,
49
+ video_grid_thw: Optional[torch.LongTensor] = None,
50
+ mm_token_type_ids: Optional[torch.IntTensor] = None,
51
+ **kwargs,
52
+ ):
53
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
54
+ output_hidden_states = (
55
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
56
+ )
57
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
58
+
59
+ outputs = self.model(
60
+ input_ids=input_ids,
61
+ position_ids=position_ids,
62
+ attention_mask=attention_mask,
63
+ past_key_values=past_key_values,
64
+ inputs_embeds=inputs_embeds,
65
+ pixel_values=pixel_values,
66
+ pixel_values_videos=pixel_values_videos,
67
+ image_grid_thw=image_grid_thw,
68
+ video_grid_thw=video_grid_thw,
69
+ mm_token_type_ids=mm_token_type_ids,
70
+ output_attentions=output_attentions,
71
+ output_hidden_states=output_hidden_states,
72
+ return_dict=return_dict,
73
+ **kwargs,
74
+ )
75
+
76
+ hidden_states = outputs[0] # [B, L, D]
77
+ logits = self.rm_head(hidden_states) # [B, L, output_dim]
78
+
79
+ if input_ids is not None:
80
+ batch_size = input_ids.shape[0]
81
+ else:
82
+ batch_size = inputs_embeds.shape[0]
83
+
84
+ pad_token_id = self.config.text_config.pad_token_id
85
+ if pad_token_id is None and batch_size != 1:
86
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
87
+ if pad_token_id is None:
88
+ sequence_lengths = -1
89
+ else:
90
+ if input_ids is not None:
91
+ sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
92
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
93
+ sequence_lengths = sequence_lengths.to(logits.device)
94
+ else:
95
+ sequence_lengths = -1
96
+
97
+ if self.reward_token == "last":
98
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
99
+ elif self.reward_token == "mean":
100
+ valid_lengths = torch.clamp(sequence_lengths, min=0, max=logits.size(1) - 1)
101
+ pooled_logits = torch.stack([logits[i, :valid_lengths[i]].mean(dim=0) for i in range(batch_size)])
102
+ elif self.reward_token == "special":
103
+ special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
104
+ for special_token_id in self.special_token_ids:
105
+ special_token_mask = special_token_mask | (input_ids == special_token_id)
106
+ pooled_logits = logits[special_token_mask, ...]
107
+ num_matched = special_token_mask.sum(dim=1)
108
+ num_dims = num_matched[0].item()
109
+ pooled_logits = pooled_logits.view(batch_size, num_dims, -1)
110
+ if self.use_ordinal:
111
+ pooled_logits = pooled_logits.view(batch_size, -1)
112
+ else:
113
+ if self.output_dim == num_dims:
114
+ pooled_logits = pooled_logits.diagonal(dim1=1, dim2=2)
115
+ pooled_logits = pooled_logits.view(batch_size, -1)
116
+ else:
117
+ raise ValueError(f"Invalid reward_token: {self.reward_token}")
118
+
119
+ return {"logits": pooled_logits}
120
+
121
+
122
+ def ordinal_predict(logits: np.ndarray, num_classes: int):
123
+ """
124
+ Convert CORN ordinal logits to predicted scores.
125
+
126
+ Args:
127
+ logits: [B, D, K-1] raw threshold logits
128
+ num_classes: K (number of ordinal classes)
129
+
130
+ Returns:
131
+ hard_preds: [B, D] integer predictions in {1..K}
132
+ soft_preds: [B, D] continuous expected value E[Y]
133
+ """
134
+ probs = 1.0 / (1.0 + np.exp(-logits)) # sigmoid → P(Y>k | Y>=k)
135
+ cum_probs = np.cumprod(probs, axis=-1) # P(Y>k) = prod_{j<=k} P(Y>j|Y>=j)
136
+
137
+ hard_preds = (cum_probs > 0.5).sum(axis=-1) + 1 # [B, D]
138
+
139
+ cum_ext = np.concatenate([
140
+ np.ones((*cum_probs.shape[:-1], 1)),
141
+ cum_probs,
142
+ np.zeros((*cum_probs.shape[:-1], 1)),
143
+ ], axis=-1)
144
+ p_class = cum_ext[..., :-1] - cum_ext[..., 1:]
145
+ p_class = np.maximum(p_class, 0)
146
+ class_values = np.arange(1, num_classes + 1)
147
+ soft_preds = (p_class * class_values).sum(axis=-1)
148
+
149
+ return hard_preds, soft_preds
vefx_reward/prompt_template.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt templates for VEFX-Reward video editing quality evaluation.
3
+ """
4
+
5
+ EDITREWARD_V2_SPECIAL = """You are an expert evaluator assessing the quality of AI-generated video edits. You will be provided with two videos:
6
+ - **Video 1**: The Original Video (before editing)
7
+ - **Video 2**: The Edited Video (after editing)
8
+
9
+ The editing instruction is:
10
+ "{text_prompt}"
11
+
12
+ Your task is to evaluate the Edited Video across THREE independent dimensions. Each dimension is scored on a 1–4 integer scale. **Scores across dimensions are independent** — a failure in one dimension must NOT affect scores in another.
13
+
14
+ ---
15
+
16
+ ## Dimension 1: Instructional Following (IF)
17
+ **Core question:** Does the edited video accurately reflect the semantic requirements of the editing instruction?
18
+
19
+ Evaluation criteria:
20
+ - Object replacement: If the instruction says "replace apple with orange," did the model actually generate an orange (not a lemon or tomato)?
21
+ - Action/attribute changes: If the instruction involves motion or attribute changes (e.g., "make it night"), was this correctly executed?
22
+ - Completeness: Were ALL parts of the instruction addressed, not just partial execution?
23
+
24
+ Scoring rubric:
25
+ - **4 (Perfect):** The edit precisely and completely executes all instructions. Object categories, attributes (color, shape), actions, and styles all match the instruction with no ambiguity.
26
+ - **3 (High):** The main instruction was executed, but minor details deviate. E.g., instruction asks for "red sports car" but a "red truck" was generated — the main concept "red car" is correct.
27
+ - **2 (Low):** The main instruction was partially executed but with significant deviations, or completely irrelevant modifications were made.
28
+ - **1 (Failed):** The edit has no relation to the instruction. E.g., instruction asks for "night scene" but the video remains daytime, or no change occurred at all.
29
+
30
+ **Important notes:**
31
+ - If the edit instruction asks for a camera perspective change (e.g., "shift to high angle") and the video shows no actual perspective change, score 1.
32
+ - If the instruction asks for adding/increasing objects and no new objects appear, score 1.
33
+ - A video that looks identical to the original (no edit happened) always scores 1 for IF.
34
+
35
+ Instructional Following score (integer 1-4): <|IF_reward|>
36
+
37
+ ---
38
+
39
+ ## Dimension 2: Render Quality (RQ)
40
+ **Core question:** What is the visual and temporal quality of the edited video?
41
+
42
+ Evaluation criteria:
43
+ - Naturalness and clarity: Are all parts of the video natural and sharp? Any blurriness, noise, or artifacts?
44
+ - Physical plausibility: Does object motion obey physics? Any flickering, jittering, objects disappearing/morphing unexpectedly?
45
+ - Temporal consistency: Is the video smooth frame-to-frame? Any sudden jumps, abrupt texture/color changes between frames?
46
+
47
+ Scoring rubric:
48
+ - **4 (Excellent):** Video clarity is very high with no visible defects, or only extremely minor artifacts detectable on very close inspection. Object motion fully obeys physics, smooth and natural. Visual quality is on par with or better than the original.
49
+ - **3 (Medium):** Some quality degradation exists (e.g., slight blurring, localized flickering), but all objects remain clearly identifiable. The video's overall structure is intact despite imperfections.
50
+ - **2 (Poor):** Significant quality degradation with obvious artifacts, distortion, or frame-to-frame inconsistency. Some object outlines deform, motion appears unnatural, affecting viewing experience.
51
+ - **1 (Unusable):** Video quality completely breaks down. Objects are severely deformed or unrecognizable, serious physics violations (e.g., person walking through walls, objects shattering spontaneously), heavy noise or complete blur.
52
+
53
+ **Important notes:**
54
+ - A sudden scene transition mid-video (e.g., white background abruptly becoming a construction site) counts as a physics/consistency violation — score ≤ 3.
55
+ - If the edit did NOT happen (original is preserved), RQ can still be high if the video itself looks fine — evaluate the video's visual quality independently.
56
+ - Evaluate temporal artifacts carefully: a single frame of flickering is minor (score 3), persistent warping or morphing is severe (score 1-2).
57
+
58
+ Render Quality score (integer 1-4): <|RQ_reward|>
59
+
60
+ ---
61
+
62
+ ## Dimension 3: Edit Exclusivity (EE)
63
+ **Core question:** Did the model ONLY perform the specified edit, without making unintended changes to other parts of the video?
64
+
65
+ Evaluation criteria:
66
+ - Over-editing: When editing a foreground object, did the background, lighting, or other unrelated objects change?
67
+ - Scene consistency: Are pixels, textures, and structures in non-edited regions preserved?
68
+ - Camera trajectory: Was the original camera movement preserved? (Changing camera motion when not instructed is over-editing.)
69
+ - Identity preservation: Do unedited people maintain their facial features, expressions, and body movements?
70
+
71
+ Scoring rubric:
72
+ - **4 (Perfect):** Strict exclusivity maintained. Only the target region specified by the instruction changed. All other regions (background, unrelated objects) remain identical to the original. Tiny pixel-level differences invisible to the eye are acceptable.
73
+ - **3 (Medium):** Visible over-editing occurred. Non-target areas show noticeable changes, but overall scene layout and unrelated object consistency are still preserved. E.g., replaced a cup on a table but the table style also changed, or a background window disappeared.
74
+ - **2 (Poor):** The overall scene or multiple unrelated objects changed significantly.
75
+ - **1 (Complete failure):** No exclusivity at all. The entire video looks like a completely new video. The surrounding scene changed drastically, or more than three unrelated objects underwent serious alterations.
76
+
77
+ **Important notes:**
78
+ - Camera trajectory changes (when not instructed) are over-editing — if the original video had camera motion and the edited video is static (or vice versa), penalize EE.
79
+ - For style transfer instructions (e.g., "turn into cyberpunk style"), it is expected that the entire visual style changes — this is NOT over-editing. But if text content or distinct object identities are destroyed during style transfer, that IS over-editing (score ≤ 3).
80
+ - If the edit failed (IF=1) but the rest of the video also changed, EE should still be scored low.
81
+
82
+ Edit Exclusivity score (integer 1-4): <|EE_reward|>
83
+ """
84
+
85
+
86
+ def build_prompt(instruction: str) -> str:
87
+ """Build the evaluation prompt from an editing instruction.
88
+
89
+ Args:
90
+ instruction: The video editing instruction text.
91
+
92
+ Returns:
93
+ The formatted prompt string with special reward tokens.
94
+ """
95
+ return EDITREWARD_V2_SPECIAL.format(text_prompt=instruction)
vefx_reward/vision_process.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Video processing utilities for VEFX-Reward.
3
+ Handles video loading, frame sampling, and resizing for Qwen3-VL input.
4
+ Adapted from qwen-vl-utils (https://github.com/kq-chen/qwen-vl-utils).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import base64
10
+ import logging
11
+ import math
12
+ import os
13
+ import sys
14
+ import warnings
15
+ from functools import lru_cache
16
+ from io import BytesIO
17
+
18
+ import requests
19
+ import torch
20
+ import torchvision
21
+ from packaging import version
22
+ from PIL import Image
23
+ from torchvision import io, transforms
24
+ from torchvision.transforms import InterpolationMode
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ IMAGE_FACTOR = 28
29
+ MIN_PIXELS = 4 * 28 * 28
30
+ MAX_PIXELS = 16384 * 28 * 28
31
+ MAX_RATIO = 200
32
+
33
+ VIDEO_MIN_PIXELS = 128 * 28 * 28
34
+ VIDEO_MAX_PIXELS = 768 * 28 * 28
35
+ VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
36
+ FRAME_FACTOR = 2
37
+ FPS = 2.0
38
+ FPS_MIN_FRAMES = 4
39
+ FPS_MAX_FRAMES = 768
40
+
41
+
42
+ def round_by_factor(number: int, factor: int) -> int:
43
+ return round(number / factor) * factor
44
+
45
+
46
+ def ceil_by_factor(number: int, factor: int) -> int:
47
+ return math.ceil(number / factor) * factor
48
+
49
+
50
+ def floor_by_factor(number: int, factor: int) -> int:
51
+ return math.floor(number / factor) * factor
52
+
53
+
54
+ def smart_resize(
55
+ height: int, width: int, factor: int = IMAGE_FACTOR,
56
+ min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS,
57
+ ) -> tuple[int, int]:
58
+ """Resize dimensions to be divisible by factor while respecting pixel budget."""
59
+ if max(height, width) / min(height, width) > MAX_RATIO:
60
+ raise ValueError(
61
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, "
62
+ f"got {max(height, width) / min(height, width)}"
63
+ )
64
+ h_bar = max(factor, round_by_factor(height, factor))
65
+ w_bar = max(factor, round_by_factor(width, factor))
66
+ if h_bar * w_bar > max_pixels:
67
+ beta = math.sqrt((height * width) / max_pixels)
68
+ h_bar = floor_by_factor(height / beta, factor)
69
+ w_bar = floor_by_factor(width / beta, factor)
70
+ elif h_bar * w_bar < min_pixels:
71
+ beta = math.sqrt(min_pixels / (height * width))
72
+ h_bar = ceil_by_factor(height * beta, factor)
73
+ w_bar = ceil_by_factor(width * beta, factor)
74
+ return h_bar, w_bar
75
+
76
+
77
+ def smart_nframes(ele: dict, total_frames: int, video_fps: int | float) -> int:
78
+ """Calculate the number of frames to extract based on fps or nframes config."""
79
+ assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
80
+ if "nframes" in ele:
81
+ nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
82
+ else:
83
+ fps = ele.get("fps", FPS)
84
+ min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
85
+ max_frames = floor_by_factor(
86
+ ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
87
+ )
88
+ nframes = total_frames / video_fps * fps
89
+ nframes = min(max(nframes, min_frames), max_frames)
90
+ nframes = round_by_factor(nframes, FRAME_FACTOR)
91
+ if nframes > total_frames:
92
+ nframes = total_frames
93
+ if not (FRAME_FACTOR <= nframes <= total_frames):
94
+ raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
95
+ return nframes
96
+
97
+
98
+ def _read_video_torchvision(ele: dict) -> tuple[torch.Tensor, dict]:
99
+ """Read video using torchvision.io.read_video. Returns (T, C, H, W) tensor."""
100
+ video_path = ele["video"]
101
+ if version.parse(torchvision.__version__) < version.parse("0.19.0"):
102
+ if "http://" in video_path or "https://" in video_path:
103
+ warnings.warn("torchvision < 0.19.0 does not support http/https video path.")
104
+ if "file://" in video_path:
105
+ video_path = video_path[7:]
106
+ video, audio, info = io.read_video(
107
+ video_path,
108
+ start_pts=ele.get("video_start", 0.0),
109
+ end_pts=ele.get("video_end", None),
110
+ pts_unit="sec",
111
+ output_format="TCHW",
112
+ )
113
+ total_frames, video_fps = video.size(0), info["video_fps"]
114
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
115
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
116
+ video = video[idx]
117
+ metadata = {
118
+ "total_num_frames": total_frames,
119
+ "fps": video_fps,
120
+ "frames_indices": idx,
121
+ }
122
+ return video, metadata
123
+
124
+
125
+ def is_decord_available() -> bool:
126
+ import importlib.util
127
+ return importlib.util.find_spec("decord") is not None
128
+
129
+
130
+ def _read_video_decord(ele: dict) -> tuple[torch.Tensor, dict]:
131
+ """Read video using decord.VideoReader. Returns (T, C, H, W) tensor."""
132
+ import decord
133
+ video_path = ele["video"]
134
+ vr = decord.VideoReader(video_path)
135
+ total_frames, video_fps = len(vr), vr.get_avg_fps()
136
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
137
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
138
+ video = vr.get_batch(idx).asnumpy()
139
+ video = torch.tensor(video).permute(0, 3, 1, 2) # NHWC → TCHW
140
+ metadata = {
141
+ "total_num_frames": total_frames,
142
+ "fps": video_fps,
143
+ "frames_indices": idx,
144
+ }
145
+ return video, metadata
146
+
147
+
148
+ VIDEO_READER_BACKENDS = {
149
+ "decord": _read_video_decord,
150
+ "torchvision": _read_video_torchvision,
151
+ }
152
+
153
+ FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
154
+
155
+
156
+ @lru_cache(maxsize=1)
157
+ def get_video_reader_backend() -> str:
158
+ if FORCE_QWENVL_VIDEO_READER is not None:
159
+ video_reader_backend = FORCE_QWENVL_VIDEO_READER
160
+ elif is_decord_available():
161
+ video_reader_backend = "decord"
162
+ else:
163
+ video_reader_backend = "torchvision"
164
+ print(f"vefx-reward using {video_reader_backend} to read video.", file=sys.stderr)
165
+ return video_reader_backend
166
+
167
+
168
+ def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
169
+ if "image" in ele:
170
+ image = ele["image"]
171
+ else:
172
+ image = ele["image_url"]
173
+ image_obj = None
174
+ if isinstance(image, Image.Image):
175
+ image_obj = image
176
+ elif image.startswith("http://") or image.startswith("https://"):
177
+ image_obj = Image.open(requests.get(image, stream=True).raw)
178
+ elif image.startswith("file://"):
179
+ image_obj = Image.open(image[7:])
180
+ elif image.startswith("data:image"):
181
+ if "base64," in image:
182
+ _, base64_data = image.split("base64,", 1)
183
+ data = base64.b64decode(base64_data)
184
+ image_obj = Image.open(BytesIO(data))
185
+ else:
186
+ image_obj = Image.open(image)
187
+ if image_obj is None:
188
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
189
+ image = image_obj.convert("RGB")
190
+ if "resized_height" in ele and "resized_width" in ele:
191
+ resized_height, resized_width = smart_resize(
192
+ ele["resized_height"], ele["resized_width"], factor=size_factor,
193
+ )
194
+ else:
195
+ width, height = image.size
196
+ min_pixels = ele.get("min_pixels", MIN_PIXELS)
197
+ max_pixels = ele.get("max_pixels", MAX_PIXELS)
198
+ resized_height, resized_width = smart_resize(
199
+ height, width, factor=size_factor, min_pixels=min_pixels, max_pixels=max_pixels,
200
+ )
201
+ image = image.resize((resized_width, resized_height))
202
+ return image
203
+
204
+
205
+ def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> tuple[torch.Tensor | list[Image.Image], dict | None]:
206
+ if isinstance(ele["video"], str):
207
+ video_reader_backend = get_video_reader_backend()
208
+ video, metadata = VIDEO_READER_BACKENDS[video_reader_backend](ele)
209
+ nframes, _, height, width = video.shape
210
+ min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
211
+ total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
212
+ max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
213
+ max_pixels = ele.get("max_pixels", max_pixels)
214
+ if "resized_height" in ele and "resized_width" in ele:
215
+ resized_height, resized_width = smart_resize(
216
+ ele["resized_height"], ele["resized_width"], factor=image_factor,
217
+ )
218
+ else:
219
+ resized_height, resized_width = smart_resize(
220
+ height, width, factor=image_factor,
221
+ min_pixels=min_pixels, max_pixels=max_pixels,
222
+ )
223
+ video = transforms.functional.resize(
224
+ video, [resized_height, resized_width],
225
+ interpolation=InterpolationMode.BICUBIC, antialias=True,
226
+ ).float()
227
+ return video, metadata
228
+ else:
229
+ assert isinstance(ele["video"], (list, tuple))
230
+ process_info = ele.copy()
231
+ process_info.pop("type", None)
232
+ process_info.pop("video", None)
233
+ images = [
234
+ fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
235
+ for video_element in ele["video"]
236
+ ]
237
+ nframes = ceil_by_factor(len(images), FRAME_FACTOR)
238
+ if len(images) < nframes:
239
+ images.extend([images[-1]] * (nframes - len(images)))
240
+ return images, None
241
+
242
+
243
+ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
244
+ vision_infos = []
245
+ if isinstance(conversations[0], dict):
246
+ conversations = [conversations]
247
+ for conversation in conversations:
248
+ for message in conversation:
249
+ if isinstance(message["content"], list):
250
+ for ele in message["content"]:
251
+ if (
252
+ "image" in ele
253
+ or "image_url" in ele
254
+ or "video" in ele
255
+ or ele["type"] in ("image", "image_url", "video")
256
+ ):
257
+ vision_infos.append(ele)
258
+ return vision_infos
259
+
260
+
261
+ def process_vision_info(
262
+ conversations: list[dict] | list[list[dict]],
263
+ ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, list[dict] | None]:
264
+ """Process vision info from conversation messages, loading images and videos."""
265
+ vision_infos = extract_vision_info(conversations)
266
+ image_inputs = []
267
+ video_inputs = []
268
+ video_metadata_list = []
269
+ for vision_info in vision_infos:
270
+ if "image" in vision_info or "image_url" in vision_info:
271
+ image_inputs.append(fetch_image(vision_info))
272
+ elif "video" in vision_info:
273
+ video, metadata = fetch_video(vision_info)
274
+ video_inputs.append(video)
275
+ video_metadata_list.append(metadata)
276
+ else:
277
+ raise ValueError("image, image_url or video should in content.")
278
+ if len(image_inputs) == 0:
279
+ image_inputs = None
280
+ if len(video_inputs) == 0:
281
+ video_inputs = None
282
+ video_metadata_list = None
283
+ return image_inputs, video_inputs, video_metadata_list