Spaces:
Running
Running
Add VEFX-Bench reference code
Browse files- .gitignore +37 -0
- LICENSE +176 -0
- examples/batch_scoring.py +51 -0
- examples/multi_gpu_scoring.py +119 -0
- examples/quick_start.py +39 -0
- examples/sample_videos/edited.mp4 +0 -0
- examples/sample_videos/original.mp4 +0 -0
- requirements.txt +11 -0
- scripts/prepare_and_upload.py +214 -0
- setup.py +30 -0
- vefx_reward/__init__.py +14 -0
- vefx_reward/inference.py +226 -0
- vefx_reward/model.py +149 -0
- vefx_reward/prompt_template.py +95 -0
- vefx_reward/vision_process.py +283 -0
.gitignore
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.egg-info/
|
| 5 |
+
dist/
|
| 6 |
+
build/
|
| 7 |
+
*.egg
|
| 8 |
+
|
| 9 |
+
# Environment
|
| 10 |
+
.env
|
| 11 |
+
*.env
|
| 12 |
+
|
| 13 |
+
# IDE
|
| 14 |
+
.vscode/
|
| 15 |
+
.idea/
|
| 16 |
+
|
| 17 |
+
# Model weights
|
| 18 |
+
*.safetensors
|
| 19 |
+
*.pth
|
| 20 |
+
*.bin
|
| 21 |
+
*.ckpt
|
| 22 |
+
*.onnx
|
| 23 |
+
|
| 24 |
+
# Data
|
| 25 |
+
*.avi
|
| 26 |
+
*.mov
|
| 27 |
+
|
| 28 |
+
# Keep sample videos for examples
|
| 29 |
+
!examples/sample_videos/*.mp4
|
| 30 |
+
|
| 31 |
+
# OS
|
| 32 |
+
.DS_Store
|
| 33 |
+
Thumbs.db
|
| 34 |
+
|
| 35 |
+
# Outputs
|
| 36 |
+
merged_model/
|
| 37 |
+
results/
|
LICENSE
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to the Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by the Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding any notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
examples/batch_scoring.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch scoring: Evaluate multiple video edits from a CSV file.
|
| 3 |
+
|
| 4 |
+
Expected CSV format:
|
| 5 |
+
original_video,edited_video,instruction
|
| 6 |
+
path/to/orig1.mp4,path/to/edit1.mp4,"make it snowy"
|
| 7 |
+
path/to/orig2.mp4,path/to/edit2.mp4,"add a red hat"
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python examples/batch_scoring.py \
|
| 11 |
+
--csv edits.csv \
|
| 12 |
+
--output results.csv
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import csv
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from vefx_reward import VEFXReward
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main():
    """Score every video edit listed in a CSV file and write the results.

    Command-line arguments:
        --csv:    input CSV; must provide the columns ``original_video``,
                  ``edited_video`` and ``instruction``.
        --output: destination CSV (default ``results.csv``). It holds the
                  input columns followed by ``IF``, ``RQ``, ``EE`` and
                  ``Overall``.
        --model:  model path or ID handed to ``VEFXReward``.
        --device: device string handed to ``VEFXReward``.

    The input is read and validated before the model is constructed, so a
    malformed CSV is reported without first paying for a model load. An
    input without data rows yields a header-only output file instead of an
    ``IndexError``.
    """
    parser = argparse.ArgumentParser(description="Batch score video edits")
    parser.add_argument("--csv", required=True, help="Input CSV with columns: original_video, edited_video, instruction")
    parser.add_argument("--output", default="results.csv", help="Output CSV path")
    parser.add_argument("--model", default="VEFX-Reward/VEFX-Reward-4B")
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()

    # newline="" lets the csv module handle quoted fields that contain
    # embedded newlines, as recommended by the csv documentation.
    with open(args.csv, newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        # Take the header from the reader rather than rows[0], which does
        # not exist when the file has no data rows.
        input_fields = list(reader.fieldnames or [])
    print(f"Loaded {len(rows)} samples from {args.csv}")

    required_columns = ("original_video", "edited_video", "instruction")
    missing = [name for name in required_columns if name not in input_fields]
    if rows and missing:
        parser.error(f"{args.csv} is missing required column(s): {', '.join(missing)}")

    results = []
    if rows:
        model = VEFXReward(args.model, device=args.device)
        for i, row in enumerate(rows):
            scores = model.score(row["original_video"], row["edited_video"], row["instruction"])
            results.append({**row, **scores})
            print(f"[{i+1}/{len(rows)}] IF={scores['IF']:.2f} RQ={scores['RQ']:.2f} EE={scores['EE']:.2f} Overall={scores['Overall']:.2f}")

    fieldnames = input_fields + ["IF", "RQ", "EE", "Overall"]
    with open(args.output, "w", newline="") as f:
        # extrasaction="ignore" matches the multi-GPU example: any key
        # outside the declared columns is dropped instead of raising.
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(results)
    print(f"\nResults saved to {args.output}")


if __name__ == "__main__":
    main()
|
examples/multi_gpu_scoring.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-GPU parallel scoring using subprocess workers.
|
| 3 |
+
|
| 4 |
+
Splits a CSV of video edits across multiple GPUs for faster inference.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python examples/multi_gpu_scoring.py \
|
| 8 |
+
--csv edits.csv \
|
| 9 |
+
--output results.csv \
|
| 10 |
+
--num_gpus 4
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import csv
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import subprocess
|
| 18 |
+
import sys
|
| 19 |
+
import tempfile
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def worker_main(args):
    """Score one shard on a single GPU and dump the results as JSON.

    Reads the list of items stored in ``args.shard_file``, scores each one
    with a ``VEFXReward`` model placed on ``cuda:0``, and writes the list of
    result records to ``args.output_file``. An item whose processing raises
    is recorded with ``None`` scores plus an ``error`` message, so one bad
    sample does not abort the shard. ``args.gpu_id`` is only used to label
    the progress output.
    """
    import torch
    from vefx_reward import VEFXReward

    with open(args.shard_file) as shard_fh:
        shard = json.load(shard_fh)

    model = VEFXReward(args.model, device="cuda:0")
    tag = f"[GPU {args.gpu_id}]"
    total = len(shard)

    results = []
    for position, item in enumerate(shard, start=1):
        progress = f"{tag} [{position}/{total}]"
        try:
            scores = model.score(item["original_video"], item["edited_video"], item["instruction"])
            results.append({**item, **scores})
            print(f"{progress} "
                  f"IF={scores['IF']:.2f} RQ={scores['RQ']:.2f} EE={scores['EE']:.2f}", flush=True)
        except Exception as e:
            print(f"{progress} ERROR: {e}", flush=True)
            failure = dict(item)
            failure.update({"IF": None, "RQ": None, "EE": None, "Overall": None, "error": str(e)})
            results.append(failure)

    with open(args.output_file, "w") as out_fh:
        json.dump(results, out_fh)
    print(f"{tag} Done — {len(results)} results", flush=True)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def main():
    """Launch one scoring worker per GPU and merge their results into a CSV.

    In launcher mode (the default) the input CSV is split round-robin into
    ``--num_gpus`` shards. Each shard is scored by a subprocess running this
    same script with ``--_worker`` and ``CUDA_VISIBLE_DEVICES`` set to the
    shard's GPU index. The per-shard JSON results are then merged, restored
    to input order, and written to ``--output``.

    Command-line arguments:
        --csv:      input CSV (required in launcher mode).
        --output:   output CSV path (default ``results.csv``).
        --model:    model path or ID forwarded to the workers.
        --num_gpus: number of worker processes / GPUs (default 4, minimum 1).

    The hidden ``--_worker``, ``--gpu_id``, ``--shard_file`` and
    ``--output_file`` options are used by the launcher to start workers.

    If any worker exits with a non-zero status or leaves no readable result
    file, the available results are still written, a warning is printed to
    stderr, and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Multi-GPU video edit scoring")
    # --csv is deliberately NOT required at the parser level: workers are
    # started without it, so requiring it here would make every worker exit
    # with a usage error. Launcher mode checks for it explicitly below.
    parser.add_argument("--csv", default=None, help="Input CSV")
    parser.add_argument("--output", default="results.csv", help="Output CSV")
    parser.add_argument("--model", default="VEFX-Reward/VEFX-Reward-4B")
    parser.add_argument("--num_gpus", type=int, default=4)
    # Internal worker args
    parser.add_argument("--_worker", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--gpu_id", type=int, default=0, help=argparse.SUPPRESS)
    parser.add_argument("--shard_file", type=str, default="", help=argparse.SUPPRESS)
    parser.add_argument("--output_file", type=str, default="", help=argparse.SUPPRESS)
    args = parser.parse_args()

    if args._worker:
        worker_main(args)
        return

    # --- Launcher mode ---
    if not args.csv:
        parser.error("the following arguments are required: --csv")
    if args.num_gpus < 1:
        parser.error("--num_gpus must be at least 1")

    with open(args.csv, newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        # Use the parsed header so an input without data rows still works.
        input_fields = list(reader.fieldnames or [])
    print(f"Loaded {len(rows)} samples, distributing across {args.num_gpus} GPUs")

    # Tag each item with its input position. Workers copy every item key
    # into the result record, so the tag survives the round trip and lets
    # the merged results be put back into input order.
    index_key = "__vefx_row_index__"
    shards = [[] for _ in range(args.num_gpus)]
    for i, row in enumerate(rows):
        shards[i % args.num_gpus].append({**row, index_key: i})

    script = os.path.abspath(__file__)
    all_results = []
    failed_gpus = set()

    # The context manager removes the scratch directory even if launching
    # or merging raises.
    with tempfile.TemporaryDirectory(prefix="vefx_multi_") as tmpdir:
        procs = []
        for gid, shard in enumerate(shards):
            if not shard:
                continue
            sf = os.path.join(tmpdir, f"shard_{gid}.json")
            of = os.path.join(tmpdir, f"result_{gid}.json")
            with open(sf, "w") as f:
                json.dump(shard, f)
            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = str(gid)
            env["TOKENIZERS_PARALLELISM"] = "false"
            p = subprocess.Popen(
                [sys.executable, script,
                 "--_worker", "--gpu_id", str(gid),
                 "--shard_file", sf, "--output_file", of,
                 "--model", args.model],
                env=env, stdout=sys.stdout, stderr=sys.stderr,
            )
            procs.append((gid, p, of))

        for gid, p, _ in procs:
            if p.wait() != 0:
                failed_gpus.add(gid)

        # Merge results
        for gid, _, of in procs:
            try:
                with open(of) as f:
                    all_results.extend(json.load(f))
            except (OSError, ValueError):
                # Missing or truncated result file: the worker died before
                # it finished writing.
                failed_gpus.add(gid)

    all_results.sort(key=lambda record: record.get(index_key, len(rows)))

    fieldnames = input_fields + ["IF", "RQ", "EE", "Overall"]
    # Surface per-row failures instead of silently dropping the messages.
    if "error" not in fieldnames and any(record.get("error") for record in all_results):
        fieldnames.append("error")

    with open(args.output, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(all_results)
    print(f"\nAll done — {len(all_results)} results saved to {args.output}")

    if failed_gpus:
        gpu_list = ", ".join(str(gid) for gid in sorted(failed_gpus))
        print(f"WARNING: worker(s) for GPU {gpu_list} failed; "
              f"{len(rows) - len(all_results)} of {len(rows)} samples are missing from the output",
              file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
examples/quick_start.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick start: Score a single video edit with VEFX-Reward.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python examples/quick_start.py \
|
| 6 |
+
--original path/to/original.mp4 \
|
| 7 |
+
--edited path/to/edited.mp4 \
|
| 8 |
+
--instruction "add a hat to the person"
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import torch
|
| 13 |
+
from vefx_reward import VEFXReward
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main():
    """Score a single original/edited video pair and print the scores.

    Parses ``--original``, ``--edited`` and ``--instruction`` (all
    required) plus the optional ``--model`` and ``--device``, builds a
    ``VEFXReward`` model, scores the pair, and prints the ``IF``, ``RQ``,
    ``EE`` and ``Overall`` values with two decimals.
    """
    cli = argparse.ArgumentParser(description="Score a video edit with VEFX-Reward")
    cli.add_argument("--original", required=True, help="Path to original video")
    cli.add_argument("--edited", required=True, help="Path to edited video")
    cli.add_argument("--instruction", required=True, help="Editing instruction")
    cli.add_argument("--model", default="VEFX-Reward/VEFX-Reward-4B", help="Model path or HF ID")
    cli.add_argument("--device", default="cuda", help="Device (cuda / cpu)")
    args = cli.parse_args()

    scorer = VEFXReward(args.model, device=args.device)
    scores = scorer.score(args.original, args.edited, args.instruction)

    rule = "=" * 50
    # (label, score key) pairs in display order.
    report_lines = (
        (" Instructional Following (IF): ", "IF"),
        (" Render Quality (RQ): ", "RQ"),
        (" Edit Exclusivity (EE): ", "EE"),
        (" Overall : ", "Overall"),
    )

    print("\n" + rule)
    print("VEFX-Reward Scores")
    print(rule)
    for label, key in report_lines:
        print(f"{label}{scores[key]:.2f}")
    print(rule)


if __name__ == "__main__":
    main()
|
examples/sample_videos/edited.mp4
ADDED
|
Binary file (21.5 kB). View file
|
|
|
examples/sample_videos/original.mp4
ADDED
|
Binary file (21.4 kB). View file
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.1.0
|
| 2 |
+
torchvision>=0.16.0
|
| 3 |
+
transformers>=4.51.0
|
| 4 |
+
accelerate>=0.30.0
|
| 5 |
+
safetensors>=0.4.0
|
| 6 |
+
huggingface_hub>=0.20.0
|
| 7 |
+
Pillow>=10.0.0
|
| 8 |
+
numpy>=1.24.0
|
| 9 |
+
requests
|
| 10 |
+
packaging
|
| 11 |
+
decord>=0.6.0
|
scripts/prepare_and_upload.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prepare and upload VEFX-Reward model to HuggingFace Hub.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Loads the Qwen3-VL-4B base model
|
| 6 |
+
2. Manually merges LoRA weights into the base model
|
| 7 |
+
3. Loads non-LoRA weights (rm_head, merger, special token embeddings)
|
| 8 |
+
4. Saves and uploads the complete model to HuggingFace
|
| 9 |
+
|
| 10 |
+
Prerequisites:
|
| 11 |
+
pip install huggingface_hub safetensors
|
| 12 |
+
huggingface-cli login
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
python scripts/prepare_and_upload.py \
|
| 16 |
+
--checkpoint_dir /path/to/training/logs/v4/ord_4B_lora_2stage_promptv2_res399k \
|
| 17 |
+
--checkpoint_step 1050 \
|
| 18 |
+
--hf_repo VEFX-Reward/VEFX-Reward-4B \
|
| 19 |
+
--output_dir ./merged_model
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
|
| 26 |
+
import safetensors.torch as st
|
| 27 |
+
import torch
|
| 28 |
+
from transformers import AutoProcessor, AutoTokenizer
|
| 29 |
+
|
| 30 |
+
import sys
|
| 31 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 32 |
+
from vefx_reward.model import Qwen3VLRewardModelBT
|
| 33 |
+
|
| 34 |
+
# Reward tokens registered with the tokenizer when the checkpoint does not
# ship its own tokenizer directory. The order is significant: it determines
# the order of the token IDs derived from this list.
SPECIAL_TOKENS = [
    "<|VQ_reward|>",
    "<|MQ_reward|>",
    "<|TA_reward|>",
    "<|IF_reward|>",
    "<|RQ_reward|>",
    "<|EE_reward|>",
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main():
    """Merge a LoRA training checkpoint into the base model and save/upload it.

    Steps:
        1. Read ``model_config.json`` from ``--checkpoint_dir``.
        2. Pick the checkpoint directory (``--checkpoint_step``, or the
           highest-numbered ``checkpoint-<step>`` when it is -1).
        3. Build the processor, using the checkpoint's tokenizer if present.
        4. Load the base model with the reward head on CPU.
        5. Merge the LoRA adapter weights into the base weights by hand.
        6. Overlay the non-LoRA weights saved with the checkpoint.
        7. Save the merged model, the processor and ``vefx_config.json`` to
           ``--output_dir``.
        8. Upload the folder to ``--hf_repo`` when ``--upload`` is given.

    Exits with a usage error when the requested checkpoint cannot be found.
    """
    import glob as globmod
    import math

    parser = argparse.ArgumentParser(description="Prepare and upload VEFX-Reward model")
    parser.add_argument("--checkpoint_dir", required=True,
                        help="Training output directory containing model_config.json and checkpoint-*/")
    parser.add_argument("--checkpoint_step", type=int, default=-1,
                        help="Checkpoint step to use (-1 = latest)")
    parser.add_argument("--hf_repo", default="VEFX-Reward/VEFX-Reward-4B",
                        help="HuggingFace repo ID to upload to")
    parser.add_argument("--output_dir", default="./merged_model",
                        help="Local directory to save merged model before upload")
    parser.add_argument("--upload", action="store_true",
                        help="Actually upload to HuggingFace (otherwise just save locally)")
    args = parser.parse_args()

    # 1. Load training config
    config_path = os.path.join(args.checkpoint_dir, "model_config.json")
    with open(config_path) as f:
        config_dict = json.load(f)

    model_config = config_dict["model_config"]
    data_config = config_dict["data_config"]

    base_model_path = model_config["model_name_or_path"]
    output_dim = model_config["output_dim"]
    use_ordinal = model_config["use_ordinal"]
    num_classes = model_config["num_classes"]
    reward_token = model_config["reward_token"]

    print(f"Base model: {base_model_path}")
    print(f"Output dim: {output_dim}, Ordinal: {use_ordinal}, Num classes: {num_classes}")

    # 2. Find checkpoint
    if args.checkpoint_step == -1:
        # Only directories named checkpoint-<digits> are candidates, so a
        # stray entry such as "checkpoint-best" cannot break the step sort.
        numbered = []
        for path in globmod.glob(os.path.join(args.checkpoint_dir, "checkpoint-*")):
            suffix = os.path.basename(path).rsplit("-", 1)[-1]
            if suffix.isdigit() and os.path.isdir(path):
                numbered.append((int(suffix), path))
        if not numbered:
            parser.error(f"no checkpoint-<step> directories found in {args.checkpoint_dir}")
        ckpt_path = max(numbered)[1]
    else:
        ckpt_path = os.path.join(args.checkpoint_dir, f"checkpoint-{args.checkpoint_step}")
        if not os.path.isdir(ckpt_path):
            parser.error(f"checkpoint directory not found: {ckpt_path}")
    print(f"Using checkpoint: {ckpt_path}")

    # 3. Load processor from base model with checkpoint's tokenizer
    processor = AutoProcessor.from_pretrained(base_model_path, padding_side="right")
    ckpt_tokenizer_path = os.path.join(ckpt_path, "tokenizer")
    if os.path.isdir(ckpt_tokenizer_path):
        processor.tokenizer = AutoTokenizer.from_pretrained(ckpt_tokenizer_path)
    else:
        processor.tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
    special_token_ids = processor.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    print(f"Tokenizer vocab size: {len(processor.tokenizer)}")

    # 4. Load base model with reward head
    print("Loading base model...")
    model = Qwen3VLRewardModelBT.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        output_dim=output_dim,
        reward_token=reward_token,
        special_token_ids=special_token_ids,
        use_ordinal=use_ordinal,
        num_classes=num_classes,
        use_cache=True,
        device_map="cpu",
    )
    model.resize_token_embeddings(len(processor.tokenizer))
    print(f"Model embeddings resized to {len(processor.tokenizer)}")

    # 5. Manual LoRA merge (bypasses PEFT tie_word_embeddings issues)
    print("Loading and merging adapter weights...")
    adapter_weights = st.load_file(os.path.join(ckpt_path, "adapter_model.safetensors"), device="cpu")

    with open(os.path.join(ckpt_path, "adapter_config.json")) as f:
        lora_cfg = json.load(f)
    rank = lora_cfg["r"]
    # Rank-stabilised LoRA scales by alpha / sqrt(r) instead of alpha / r.
    if lora_cfg.get("use_rslora", False):
        scaling = lora_cfg["lora_alpha"] / math.sqrt(rank)
    else:
        scaling = lora_cfg["lora_alpha"] / rank
    if lora_cfg.get("rank_pattern") or lora_cfg.get("alpha_pattern"):
        # A single global scaling factor is applied below; per-module
        # overrides are not honoured by this manual merge.
        print("  WARNING: adapter config defines rank_pattern/alpha_pattern; "
              "per-module scaling overrides are ignored by this script")

    # Categorize adapter keys
    base_layers, lora_As, lora_Bs, emb_As, emb_Bs = {}, {}, {}, {}, {}
    for k, v in adapter_weights.items():
        ck = k.replace("base_model.model.", "")
        if ".base_layer.weight" in ck:
            base_layers[ck.replace(".base_layer.weight", "")] = v
        elif ".lora_A.weight" in ck:
            lora_As[ck.replace(".lora_A.weight", "")] = v
        elif ".lora_B.weight" in ck:
            lora_Bs[ck.replace(".lora_B.weight", "")] = v
        elif ".lora_embedding_A" in ck:
            emb_As[ck.replace(".lora_embedding_A", "")] = v
        elif ".lora_embedding_B" in ck:
            emb_Bs[ck.replace(".lora_embedding_B", "")] = v

    model_state = model.state_dict()

    # Replace base layer weights (for resized lm_head / embed_tokens)
    for mod, w in base_layers.items():
        key = mod + ".weight"
        if key in model_state:
            model_state[key] = w.to(model_state[key].dtype)
            print(f"  Replaced base layer: {key}")

    # Merge LoRA: W_merged = W + B @ A * scaling
    merged_count = 0
    unmatched = []
    for mod, a_weight in lora_As.items():
        key = mod + ".weight"
        b_weight = lora_Bs.get(mod)
        if b_weight is None or key not in model_state:
            unmatched.append(mod)
            continue
        target = model_state[key]
        delta = (b_weight.float() @ a_weight.float()) * scaling
        # Accumulate in float32, then cast back to the weight's own dtype.
        model_state[key] = (target.float() + delta).to(target.dtype)
        merged_count += 1
    print(f"  Merged {merged_count} LoRA modules")

    # Merge embedding LoRA
    for mod, a_weight in emb_As.items():
        key = mod + ".weight"
        b_weight = emb_Bs.get(mod)
        if b_weight is None or key not in model_state:
            unmatched.append(mod)
            continue
        target = model_state[key]
        # The product is transposed so that it lines up with the embedding
        # weight before being added.
        delta = (b_weight.float() @ a_weight.float()).T * scaling
        model_state[key] = (target.float() + delta).to(target.dtype)
        print(f"  Merged embedding LoRA: {key}")

    if unmatched:
        # A silently skipped module would leave the saved model only
        # partially merged, so make it visible.
        print(f"  WARNING: {len(unmatched)} adapter module(s) could not be merged "
              f"(no matching LoRA pair or model weight), e.g. {unmatched[0]}")

    # 6. Load non-LoRA weights (rm_head, merger, special embeddings)
    non_lora_path = os.path.join(ckpt_path, "non_lora_state_dict.pth")
    if os.path.exists(non_lora_path):
        print("Loading non-LoRA weights...")
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state dict should contain.
        non_lora_weights = torch.load(non_lora_path, map_location="cpu", weights_only=True)
        for k, v in non_lora_weights.items():
            ck = k.replace("base_model.model.", "")
            if ck in model_state:
                model_state[ck] = v.to(model_state[ck].dtype)
                print(f"  Loaded: {ck}")
    else:
        print(f"  WARNING: {non_lora_path} not found; non-LoRA weights keep their "
              "values from the base model load")

    model.load_state_dict(model_state)
    print(f"All weights loaded. rm_head shape: {model.rm_head.weight.shape}")

    # 7. Save merged model
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Saving merged model to {args.output_dir}...")
    model.save_pretrained(args.output_dir, safe_serialization=True)
    processor.save_pretrained(args.output_dir)

    # Save VEFX-specific config
    vefx_config = {
        "output_dim": output_dim,
        "use_ordinal": use_ordinal,
        "num_classes": num_classes,
        "reward_token": reward_token,
        "fps": data_config.get("fps", 4.0),
        "max_frame_pixels": data_config.get("max_frame_pixels", 399360),
        "eval_dim": data_config.get("eval_dim", ["IF", "RQ", "EE"]),
        "prompt_template_type": data_config.get("prompt_template_type", "editreward_v2_special"),
    }
    with open(os.path.join(args.output_dir, "vefx_config.json"), "w") as f:
        json.dump(vefx_config, f, indent=2)
    print("Saved vefx_config.json")

    # 8. Upload to HuggingFace
    if args.upload:
        from huggingface_hub import HfApi
        api = HfApi()
        print(f"Uploading to {args.hf_repo}...")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.hf_repo,
            repo_type="model",
        )
        print(f"Upload complete: https://huggingface.co/{args.hf_repo}")
    else:
        print(f"\nModel saved to {args.output_dir}")
        print("To upload, run again with --upload flag, or manually:")
        print(f"  huggingface-cli upload {args.hf_repo} {args.output_dir} .")


if __name__ == "__main__":
    main()
|
setup.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path

from setuptools import setup, find_packages

# Locate README.md next to this file so the build does not depend on the
# current working directory, and read it with an explicit encoding because the
# platform default is not guaranteed to be UTF-8.
_README_PATH = Path(__file__).resolve().parent / "README.md"

# A missing README should not abort installation; fall back to an empty long
# description in that case.
_LONG_DESCRIPTION = (
    _README_PATH.read_text(encoding="utf-8") if _README_PATH.is_file() else ""
)

setup(
    name="vefx-reward",
    version="0.1.0",
    description="VEFX-Reward: A reward model for video editing quality assessment",
    long_description=_LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    url="",
    packages=find_packages(),
    python_requires=">=3.10",
    install_requires=[
        "torch>=2.1.0",
        "torchvision>=0.16.0",
        "transformers>=4.51.0",
        "accelerate>=0.30.0",
        "safetensors>=0.4.0",
        "huggingface_hub>=0.20.0",
        "Pillow>=10.0.0",
        "numpy>=1.24.0",
        "requests",
        "packaging",
        "decord>=0.6.0",
    ],
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
|
vefx_reward/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VEFX-Reward: A reward model for video editing quality assessment.
|
| 3 |
+
|
| 4 |
+
Evaluates video edits on three dimensions (1–4 scale):
|
| 5 |
+
- IF (Instructional Following)
|
| 6 |
+
- RQ (Render Quality)
|
| 7 |
+
- EE (Edit Exclusivity)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
__version__ = "0.1.0"
|
| 11 |
+
|
| 12 |
+
from .inference import VEFXReward
|
| 13 |
+
|
| 14 |
+
__all__ = ["VEFXReward"]
|
vefx_reward/inference.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VEFX-Reward: Video editing quality assessment inference API.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
from vefx_reward import VEFXReward
|
| 6 |
+
|
| 7 |
+
model = VEFXReward("VEFX-Reward/VEFX-Reward-4B", device="cuda")
|
| 8 |
+
scores = model.score("original.mp4", "edited.mp4", "add a hat to the person")
|
| 9 |
+
# {'IF': 3.21, 'RQ': 2.85, 'EE': 3.54, 'Overall': 9.60}
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
from collections.abc import Mapping
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
from transformers import AutoProcessor
|
| 20 |
+
|
| 21 |
+
from .model import Qwen3VLRewardModelBT, ordinal_predict
|
| 22 |
+
from .prompt_template import build_prompt
|
| 23 |
+
from .vision_process import process_vision_info
|
| 24 |
+
|
| 25 |
+
# Default model hyperparameters (matching the released VEFX-Reward-4B)
|
| 26 |
+
DEFAULT_FPS = 4.0
|
| 27 |
+
DEFAULT_MAX_FRAME_PIXELS = 399360
|
| 28 |
+
DEFAULT_NUM_CLASSES = 4
|
| 29 |
+
DEFAULT_OUTPUT_DIM = 3
|
| 30 |
+
DIMS = ["IF", "RQ", "EE"]
|
| 31 |
+
|
| 32 |
+
SPECIAL_TOKENS = [
|
| 33 |
+
"<|VQ_reward|>", "<|MQ_reward|>", "<|TA_reward|>",
|
| 34 |
+
"<|IF_reward|>", "<|RQ_reward|>", "<|EE_reward|>",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class VEFXReward:
    """VEFX-Reward model for video editing quality assessment.

    Scores video edits on three dimensions (1–4 scale):
    - **IF** (Instructional Following): How well the edit follows the instruction.
    - **RQ** (Render Quality): Visual and temporal quality of the edited video.
    - **EE** (Edit Exclusivity): Whether only the intended region was modified.

    Args:
        model_path: HuggingFace model ID or local path
            (e.g., ``"VEFX-Reward/VEFX-Reward-4B"``).
        device: Device string (default ``"cuda"``).
        dtype: Torch dtype (default ``torch.bfloat16``).
        fps: Frames per second for video sampling (default 4.0).
        max_frame_pixels: Maximum pixels per frame (default 399360).

    Example::

        model = VEFXReward("VEFX-Reward/VEFX-Reward-4B")
        scores = model.score("original.mp4", "edited.mp4", "make it snowy")
        print(scores)
        # {'IF': 3.2, 'RQ': 2.9, 'EE': 3.5, 'Overall': 9.6}
    """

    def __init__(
        self,
        model_path: str = "VEFX-Reward/VEFX-Reward-4B",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        fps: float = DEFAULT_FPS,
        max_frame_pixels: int = DEFAULT_MAX_FRAME_PIXELS,
    ):
        self.device = device
        self.dtype = dtype
        self.fps = fps
        self.max_frame_pixels = max_frame_pixels

        # Load config (local directory first, then the HF hub, then defaults).
        vefx_config = self._load_vefx_config(model_path)

        self.num_classes = vefx_config.get("num_classes", DEFAULT_NUM_CLASSES)
        self.output_dim = vefx_config.get("output_dim", DEFAULT_OUTPUT_DIM)
        self.use_ordinal = vefx_config.get("use_ordinal", True)
        reward_token = vefx_config.get("reward_token", "special")

        # Load processor and register any reward tokens the tokenizer lacks.
        self.processor = AutoProcessor.from_pretrained(model_path, padding_side="right")
        existing_tokens = set(self.processor.tokenizer.get_vocab().keys())
        tokens_to_add = [t for t in SPECIAL_TOKENS if t not in existing_tokens]
        if tokens_to_add:
            self.processor.tokenizer.add_special_tokens({"additional_special_tokens": tokens_to_add})
        special_token_ids = self.processor.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)

        # Load model
        self.model = Qwen3VLRewardModelBT.from_pretrained(
            model_path,
            torch_dtype=dtype,
            output_dim=self.output_dim,
            reward_token=reward_token,
            special_token_ids=special_token_ids,
            use_ordinal=self.use_ordinal,
            num_classes=self.num_classes,
            use_cache=True,
        )
        self.model.resize_token_embeddings(len(self.processor.tokenizer))

        self.model.eval().to(self.device)
        print(f"VEFX-Reward loaded on {self.device} ({dtype})")

    @staticmethod
    def _load_vefx_config(model_path: str) -> dict:
        """Return the contents of ``vefx_config.json`` for ``model_path``.

        A local ``<model_path>/vefx_config.json`` is used when it exists.
        Otherwise the file is requested from the HuggingFace hub. If that
        fails for any reason, a warning is issued and an empty dict is
        returned so the caller falls back to the module defaults.
        """
        if os.path.isdir(model_path):
            local_config = os.path.join(model_path, "vefx_config.json")
            if os.path.exists(local_config):
                with open(local_config, encoding="utf-8") as f:
                    return json.load(f)

        # Try to download from HF hub. This is deliberately best-effort, but
        # the failure is reported so a silent fallback to defaults cannot hide
        # a misconfigured path or a network problem.
        try:
            from huggingface_hub import hf_hub_download

            hub_config = hf_hub_download(model_path, "vefx_config.json")
            with open(hub_config, encoding="utf-8") as f:
                return json.load(f)
        except Exception as exc:
            import warnings

            warnings.warn(
                f"Could not load vefx_config.json for {model_path!r} ({exc}); "
                "falling back to default hyperparameters.",
                stacklevel=3,
            )
            return {}

    def _prepare_input(self, data):
        """Recursively move every tensor inside ``data`` to ``self.device``.

        Mappings, tuples and lists are rebuilt with their own type; any other
        non-tensor value is returned unchanged.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            return data.to(device=self.device)
        return data

    def _build_batch(self, original_video: str, edited_video: str, instruction: str):
        """Build a single-sample batch from video paths and instruction."""
        content = [
            {
                "type": "video",
                "video": f"file://{os.path.abspath(original_video)}",
                "max_pixels": self.max_frame_pixels,
                "fps": self.fps,
                "sample_type": "uniform",
            },
            {
                "type": "video",
                "video": f"file://{os.path.abspath(edited_video)}",
                "max_pixels": self.max_frame_pixels,
                "fps": self.fps,
                "sample_type": "uniform",
            },
            {"type": "text", "text": build_prompt(instruction)},
        ]
        messages = [[{"role": "user", "content": content}]]
        image_inputs, video_inputs, video_metadata_list = process_vision_info(messages)
        # Frames are scaled by 1/255 here, which is why the processor is told
        # not to rescale again below (presumably the loader yields 0-255
        # values; confirm against process_vision_info).
        video_inputs = [v.float() / 255.0 for v in video_inputs]

        texts = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        processor_kwargs = dict(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
            videos_kwargs={"do_rescale": False, "do_sample_frames": False},
        )
        if video_metadata_list:
            processor_kwargs["videos_kwargs"]["video_metadata"] = video_metadata_list
            processor_kwargs["videos_kwargs"]["return_metadata"] = True

        batch = self.processor(**processor_kwargs)
        return self._prepare_input(batch)

    def _logits_to_scores(self, logits: torch.Tensor) -> dict:
        """Convert raw model logits for one sample to per-dimension scores.

        In ordinal mode the logits are reshaped to
        ``(1, output_dim, num_classes - 1)`` and the soft (expected-value)
        prediction of ``ordinal_predict`` is used. Otherwise the first row of
        logits is used directly. ``'Overall'`` is the sum of the dimension
        scores that were produced.
        """
        logits_np = logits.float().cpu().numpy()
        if self.use_ordinal:
            num_thresholds = self.num_classes - 1
            logits_reshaped = logits_np.reshape(1, self.output_dim, num_thresholds)
            _, soft = ordinal_predict(logits_reshaped, self.num_classes)
            values = soft[0]
        else:
            values = logits_np[0]

        scores = {DIMS[j]: round(float(values[j]), 3) for j in range(self.output_dim)}
        # Sum only what was computed: indexing every name in DIMS would raise
        # KeyError when the model exposes fewer than len(DIMS) dimensions.
        scores["Overall"] = round(sum(scores.values()), 3)
        return scores

    @torch.no_grad()
    def score(
        self,
        original_video: str,
        edited_video: str,
        instruction: str,
    ) -> dict:
        """Score a single video edit.

        Args:
            original_video: Path to the original (source) video.
            edited_video: Path to the edited video.
            instruction: The editing instruction text.

        Returns:
            Dictionary with keys ``'IF'``, ``'RQ'``, ``'EE'``, ``'Overall'``.
            Each dimension is scored on a continuous 1–4 scale.
        """
        batch = self._build_batch(original_video, edited_video, instruction)
        logits = self.model(**batch, return_dict=True)["logits"]
        return self._logits_to_scores(logits)

    @torch.no_grad()
    def score_batch(
        self,
        original_videos: list[str],
        edited_videos: list[str],
        instructions: list[str],
    ) -> list[dict]:
        """Score multiple video edits (processed sequentially to avoid OOM).

        Args:
            original_videos: List of paths to original videos.
            edited_videos: List of paths to edited videos.
            instructions: List of editing instruction texts.

        Returns:
            List of score dictionaries, one per sample.

        Raises:
            ValueError: If the three input lists differ in length.
        """
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if not (len(original_videos) == len(edited_videos) == len(instructions)):
            raise ValueError(
                "original_videos, edited_videos and instructions must have the same "
                f"length, got {len(original_videos)}, {len(edited_videos)} and "
                f"{len(instructions)}"
            )
        return [
            self.score(orig, edit, inst)
            for orig, edit, inst in zip(original_videos, edited_videos, instructions)
        ]
|
vefx_reward/model.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VEFX-Reward: Qwen3-VL based reward model for video editing quality assessment.
|
| 3 |
+
|
| 4 |
+
Extends Qwen3VLForConditionalGeneration with an rm_head for ordinal regression,
|
| 5 |
+
scoring video edits on Instructional Following (IF), Render Quality (RQ),
|
| 6 |
+
and Edit Exclusivity (EE) on a 1–4 scale.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
from transformers import Qwen3VLForConditionalGeneration
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Qwen3VLRewardModelBT(Qwen3VLForConditionalGeneration):
    """Qwen3-VL with a reward head for ordinal video edit quality scoring.

    A bias-free linear ``rm_head`` maps every hidden state to ``output_dim``
    values. ``reward_token`` selects how per-token outputs are pooled into one
    result per sample: ``"last"``, ``"mean"`` or ``"special"`` (read the
    outputs at the positions of ``special_token_ids``).
    """

    def __init__(self, config, output_dim=3, reward_token="special",
                 special_token_ids=None, use_ordinal=True, num_classes=4, **kwargs):
        # Store ``use_cache`` on the config rather than forwarding it to the
        # parent constructor.
        if "use_cache" in kwargs:
            config.use_cache = kwargs.pop("use_cache")
        super().__init__(config, **kwargs)
        self.output_dim = output_dim
        self.rm_head = nn.Linear(config.text_config.hidden_size, output_dim, bias=False)
        nn.init.normal_(self.rm_head.weight, mean=0.0, std=1.0 / config.text_config.hidden_size)
        self.reward_token = reward_token
        self.use_ordinal = use_ordinal
        self.num_classes = num_classes
        self.special_token_ids = special_token_ids
        # Supplying special token ids always selects "special" pooling,
        # overriding whatever ``reward_token`` was passed.
        if self.special_token_ids is not None:
            self.reward_token = "special"

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        mm_token_type_ids: Optional[torch.IntTensor] = None,
        **kwargs,
    ):
        """Run the backbone and pool the reward-head outputs per sample.

        ``labels`` and ``use_cache`` are accepted for signature compatibility
        but are not used by this method.

        Returns:
            ``{"logits": pooled_logits}`` regardless of ``return_dict``.

        Raises:
            ValueError: If the batch size is greater than 1 and no padding
                token is configured, if ``reward_token`` is unknown, or if
                "special" pooling cannot locate a consistent set of reward
                tokens in ``input_ids``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            mm_token_type_ids=mm_token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        hidden_states = outputs[0]  # [B, L, D]
        logits = self.rm_head(hidden_states)  # [B, L, output_dim]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        pad_token_id = self.config.text_config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if pad_token_id is None or input_ids is None:
            # Without a pad token (or without ids) fall back to the final position.
            sequence_lengths = -1
        else:
            # Index of the token before the first pad token; the modulo maps
            # the "no padding found" result (-1) to the final position.
            sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)

        if self.reward_token == "last":
            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
        elif self.reward_token == "mean":
            # NOTE(review): the slice below excludes position ``valid_lengths[i]``
            # itself. Left unchanged because altering the pooling would change
            # model outputs; confirm whether that is intended.
            valid_lengths = torch.clamp(sequence_lengths, min=0, max=logits.size(1) - 1)
            pooled_logits = torch.stack([logits[i, :valid_lengths[i]].mean(dim=0) for i in range(batch_size)])
        elif self.reward_token == "special":
            # Fail early with a clear message instead of an opaque TypeError
            # or reshape error further down.
            if input_ids is None:
                raise ValueError("reward_token='special' requires input_ids to locate the reward tokens.")
            if self.special_token_ids is None:
                raise ValueError("reward_token='special' requires special_token_ids to be set.")

            special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
            for special_token_id in self.special_token_ids:
                special_token_mask = special_token_mask | (input_ids == special_token_id)

            num_matched = special_token_mask.sum(dim=1)
            num_dims = num_matched[0].item()
            if num_dims == 0:
                raise ValueError("No reward token found in input_ids; the prompt must contain the reward tokens.")
            if bool((num_matched != num_dims).any()):
                raise ValueError(
                    "Every sample in the batch must contain the same number of reward tokens, "
                    f"got counts {num_matched.tolist()}."
                )

            pooled_logits = logits[special_token_mask, ...]
            pooled_logits = pooled_logits.view(batch_size, num_dims, -1)
            if self.use_ordinal:
                # Flatten to [B, num_dims * output_dim]; the caller reshapes.
                pooled_logits = pooled_logits.view(batch_size, -1)
            else:
                if self.output_dim == num_dims:
                    # Reward token j reads output j of the head.
                    pooled_logits = pooled_logits.diagonal(dim1=1, dim2=2)
                    pooled_logits = pooled_logits.view(batch_size, -1)
        else:
            raise ValueError(f"Invalid reward_token: {self.reward_token}")

        return {"logits": pooled_logits}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def ordinal_predict(logits: np.ndarray, num_classes: int):
    """
    Convert CORN ordinal logits to predicted scores.

    Args:
        logits: [B, D, K-1] raw threshold logits
        num_classes: K (number of ordinal classes)

    Returns:
        hard_preds: [B, D] integer predictions in {1..K}
        soft_preds: [B, D] continuous expected value E[Y]

    Raises:
        ValueError: If the last axis of ``logits`` does not hold exactly
            ``num_classes - 1`` thresholds.
    """
    logits = np.asarray(logits)
    # Validate up front: a mismatched threshold axis can otherwise broadcast
    # silently against the class values and yield meaningless scores.
    if logits.ndim == 0 or logits.shape[-1] != num_classes - 1:
        raise ValueError(
            f"logits must have {num_classes - 1} thresholds on the last axis "
            f"for num_classes={num_classes}, got shape {logits.shape}"
        )

    # exp() overflows to inf for strongly negative logits; the sigmoid then
    # evaluates to the correct limit 0, so only the warning is suppressed.
    with np.errstate(over="ignore"):
        probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid → P(Y>k | Y>=k)
    cum_probs = np.cumprod(probs, axis=-1)  # P(Y>k) = prod_{j<=k} P(Y>j|Y>=j)

    # Count the thresholds that are more likely exceeded than not.
    hard_preds = (cum_probs > 0.5).sum(axis=-1) + 1  # [B, D]

    # Pad the exceedance curve with P(Y>0)=1 and P(Y>K)=0, then difference
    # adjacent entries to obtain the per-class probabilities.
    edge_shape = (*cum_probs.shape[:-1], 1)
    cum_ext = np.concatenate(
        [np.ones(edge_shape), cum_probs, np.zeros(edge_shape)],
        axis=-1,
    )
    p_class = np.maximum(cum_ext[..., :-1] - cum_ext[..., 1:], 0)
    class_values = np.arange(1, num_classes + 1)
    soft_preds = (p_class * class_values).sum(axis=-1)

    return hard_preds, soft_preds
|
vefx_reward/prompt_template.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prompt templates for VEFX-Reward video editing quality evaluation.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
EDITREWARD_V2_SPECIAL = """You are an expert evaluator assessing the quality of AI-generated video edits. You will be provided with two videos:
|
| 6 |
+
- **Video 1**: The Original Video (before editing)
|
| 7 |
+
- **Video 2**: The Edited Video (after editing)
|
| 8 |
+
|
| 9 |
+
The editing instruction is:
|
| 10 |
+
"{text_prompt}"
|
| 11 |
+
|
| 12 |
+
Your task is to evaluate the Edited Video across THREE independent dimensions. Each dimension is scored on a 1–4 integer scale. **Scores across dimensions are independent** — a failure in one dimension must NOT affect scores in another.
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
## Dimension 1: Instructional Following (IF)
|
| 17 |
+
**Core question:** Does the edited video accurately reflect the semantic requirements of the editing instruction?
|
| 18 |
+
|
| 19 |
+
Evaluation criteria:
|
| 20 |
+
- Object replacement: If the instruction says "replace apple with orange," did the model actually generate an orange (not a lemon or tomato)?
|
| 21 |
+
- Action/attribute changes: If the instruction involves motion or attribute changes (e.g., "make it night"), was this correctly executed?
|
| 22 |
+
- Completeness: Were ALL parts of the instruction addressed, not just partial execution?
|
| 23 |
+
|
| 24 |
+
Scoring rubric:
|
| 25 |
+
- **4 (Perfect):** The edit precisely and completely executes all instructions. Object categories, attributes (color, shape), actions, and styles all match the instruction with no ambiguity.
|
| 26 |
+
- **3 (High):** The main instruction was executed, but minor details deviate. E.g., instruction asks for "red sports car" but a "red truck" was generated — the main concept "red car" is correct.
|
| 27 |
+
- **2 (Low):** The main instruction was partially executed but with significant deviations, or completely irrelevant modifications were made.
|
| 28 |
+
- **1 (Failed):** The edit has no relation to the instruction. E.g., instruction asks for "night scene" but the video remains daytime, or no change occurred at all.
|
| 29 |
+
|
| 30 |
+
**Important notes:**
|
| 31 |
+
- If the edit instruction asks for a camera perspective change (e.g., "shift to high angle") and the video shows no actual perspective change, score 1.
|
| 32 |
+
- If the instruction asks for adding/increasing objects and no new objects appear, score 1.
|
| 33 |
+
- A video that looks identical to the original (no edit happened) always scores 1 for IF.
|
| 34 |
+
|
| 35 |
+
Instructional Following score (integer 1-4): <|IF_reward|>
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
## Dimension 2: Render Quality (RQ)
|
| 40 |
+
**Core question:** What is the visual and temporal quality of the edited video?
|
| 41 |
+
|
| 42 |
+
Evaluation criteria:
|
| 43 |
+
- Naturalness and clarity: Are all parts of the video natural and sharp? Any blurriness, noise, or artifacts?
|
| 44 |
+
- Physical plausibility: Does object motion obey physics? Any flickering, jittering, objects disappearing/morphing unexpectedly?
|
| 45 |
+
- Temporal consistency: Is the video smooth frame-to-frame? Any sudden jumps, abrupt texture/color changes between frames?
|
| 46 |
+
|
| 47 |
+
Scoring rubric:
|
| 48 |
+
- **4 (Excellent):** Video clarity is very high with no visible defects, or only extremely minor artifacts detectable on very close inspection. Object motion fully obeys physics, smooth and natural. Visual quality is on par with or better than the original.
|
| 49 |
+
- **3 (Medium):** Some quality degradation exists (e.g., slight blurring, localized flickering), but all objects remain clearly identifiable. The video's overall structure is intact despite imperfections.
|
| 50 |
+
- **2 (Poor):** Significant quality degradation with obvious artifacts, distortion, or frame-to-frame inconsistency. Some object outlines deform, motion appears unnatural, affecting viewing experience.
|
| 51 |
+
- **1 (Unusable):** Video quality completely breaks down. Objects are severely deformed or unrecognizable, serious physics violations (e.g., person walking through walls, objects shattering spontaneously), heavy noise or complete blur.
|
| 52 |
+
|
| 53 |
+
**Important notes:**
|
| 54 |
+
- A sudden scene transition mid-video (e.g., white background abruptly becoming a construction site) counts as a physics/consistency violation — score ≤ 3.
|
| 55 |
+
- If the edit did NOT happen (original is preserved), RQ can still be high if the video itself looks fine — evaluate the video's visual quality independently.
|
| 56 |
+
- Evaluate temporal artifacts carefully: a single frame of flickering is minor (score 3), persistent warping or morphing is severe (score 1-2).
|
| 57 |
+
|
| 58 |
+
Render Quality score (integer 1-4): <|RQ_reward|>
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
|
| 62 |
+
## Dimension 3: Edit Exclusivity (EE)
|
| 63 |
+
**Core question:** Did the model ONLY perform the specified edit, without making unintended changes to other parts of the video?
|
| 64 |
+
|
| 65 |
+
Evaluation criteria:
|
| 66 |
+
- Over-editing: When editing a foreground object, did the background, lighting, or other unrelated objects change?
|
| 67 |
+
- Scene consistency: Are pixels, textures, and structures in non-edited regions preserved?
|
| 68 |
+
- Camera trajectory: Was the original camera movement preserved? (Changing camera motion when not instructed is over-editing.)
|
| 69 |
+
- Identity preservation: Do unedited people maintain their facial features, expressions, and body movements?
|
| 70 |
+
|
| 71 |
+
Scoring rubric:
|
| 72 |
+
- **4 (Perfect):** Strict exclusivity maintained. Only the target region specified by the instruction changed. All other regions (background, unrelated objects) remain identical to the original. Tiny pixel-level differences invisible to the eye are acceptable.
|
| 73 |
+
- **3 (Medium):** Visible over-editing occurred. Non-target areas show noticeable changes, but overall scene layout and unrelated object consistency are still preserved. E.g., replaced a cup on a table but the table style also changed, or a background window disappeared.
|
| 74 |
+
- **2 (Poor):** The overall scene or multiple unrelated objects changed significantly.
|
| 75 |
+
- **1 (Complete failure):** No exclusivity at all. The entire video looks like a completely new video. The surrounding scene changed drastically, or more than three unrelated objects underwent serious alterations.
|
| 76 |
+
|
| 77 |
+
**Important notes:**
|
| 78 |
+
- Camera trajectory changes (when not instructed) are over-editing — if the original video had camera motion and the edited video is static (or vice versa), penalize EE.
|
| 79 |
+
- For style transfer instructions (e.g., "turn into cyberpunk style"), it is expected that the entire visual style changes — this is NOT over-editing. But if text content or distinct object identities are destroyed during style transfer, that IS over-editing (score ≤ 3).
|
| 80 |
+
- If the edit failed (IF=1) but the rest of the video also changed, EE should still be scored low.
|
| 81 |
+
|
| 82 |
+
Edit Exclusivity score (integer 1-4): <|EE_reward|>
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def build_prompt(instruction: str) -> str:
    """Fill the evaluation prompt template with an editing instruction.

    Args:
        instruction: The video editing instruction text.

    Returns:
        ``EDITREWARD_V2_SPECIAL`` with its ``{text_prompt}`` placeholder
        replaced by ``instruction``.
    """
    substitutions = {"text_prompt": instruction}
    return EDITREWARD_V2_SPECIAL.format_map(substitutions)
|
vefx_reward/vision_process.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Video processing utilities for VEFX-Reward.
|
| 3 |
+
Handles video loading, frame sampling, and resizing for Qwen3-VL input.
|
| 4 |
+
Adapted from qwen-vl-utils (https://github.com/kq-chen/qwen-vl-utils).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import base64
|
| 10 |
+
import logging
|
| 11 |
+
import math
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import warnings
|
| 15 |
+
from functools import lru_cache
|
| 16 |
+
from io import BytesIO
|
| 17 |
+
|
| 18 |
+
import requests
|
| 19 |
+
import torch
|
| 20 |
+
import torchvision
|
| 21 |
+
from packaging import version
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from torchvision import io, transforms
|
| 24 |
+
from torchvision.transforms import InterpolationMode
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
# Default spatial granularity: smart_resize returns heights/widths that are
# multiples of this value.
IMAGE_FACTOR = 28
# Default lower/upper pixel budget for a single image.
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_PIXELS = 16384 * IMAGE_FACTOR * IMAGE_FACTOR
# Largest long-side / short-side ratio smart_resize accepts.
MAX_RATIO = 200

# Default per-frame pixel bounds and whole-clip pixel budget used by fetch_video.
VIDEO_MIN_PIXELS = 128 * IMAGE_FACTOR * IMAGE_FACTOR
VIDEO_MAX_PIXELS = 768 * IMAGE_FACTOR * IMAGE_FACTOR
VIDEO_TOTAL_PIXELS = 24576 * IMAGE_FACTOR * IMAGE_FACTOR
# Sampled frame counts are rounded to multiples of this value.
FRAME_FACTOR = 2
# Default sampling rate (frames per second) when the request sets neither
# `fps` nor `nframes`, plus the default bounds on the sampled frame count.
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of ``factor`` nearest to ``number``.

    The quotient is rounded with the built-in ``round``, so exact halves go to
    the even multiple.
    """
    multiples = round(number / factor)
    return multiples * factor
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    multiples = math.ceil(number / factor)
    return multiples * factor
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    multiples = math.floor(number / factor)
    return multiples * factor
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """Pick a target (height, width) whose sides are multiples of ``factor``.

    Each side is first rounded to the nearest multiple of ``factor`` (at least
    ``factor``). If the resulting area exceeds ``max_pixels`` both sides are
    scaled down by a common ratio and floored to a multiple of ``factor``; if
    it falls below ``min_pixels`` both sides are scaled up and ceiled instead.

    Args:
        height: Source height in pixels.
        width: Source width in pixels.
        factor: Both returned sides are multiples of this value.
        min_pixels: Lower bound on the target area.
        max_pixels: Upper bound on the target area.

    Returns:
        ``(resized_height, resized_width)``.

    Raises:
        ValueError: If the long-side / short-side ratio exceeds ``MAX_RATIO``.
    """
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, "
            f"got {aspect_ratio}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Flooring the short side of an extreme-aspect input can yield 0,
        # which is not a valid resize target; never go below one `factor`.
        # In that corner case the area may slightly exceed `max_pixels`.
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def smart_nframes(ele: dict, total_frames: int, video_fps: int | float) -> int:
    """Decide how many frames to sample from a video.

    ``ele`` may carry either ``nframes`` (an explicit count, rounded to a
    multiple of ``FRAME_FACTOR``) or ``fps`` (a sampling rate, default ``FPS``)
    optionally bounded by ``min_frames`` / ``max_frames``. The result is capped
    at ``total_frames``.

    Args:
        ele: Video request dict; the keys read here are ``nframes``, ``fps``,
            ``min_frames`` and ``max_frames``.
        total_frames: Number of frames available in the video.
        video_fps: Native frame rate of the video.

    Returns:
        The number of frames to extract.

    Raises:
        ValueError: If the result is outside ``[FRAME_FACTOR, total_frames]``.
    """
    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
    if "nframes" in ele:
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    else:
        target_fps = ele.get("fps", FPS)
        lower = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        upper = floor_by_factor(
            ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
        )
        # Duration in seconds times the requested rate, then clamp and snap
        # to a multiple of FRAME_FACTOR.
        wanted = total_frames / video_fps * target_fps
        clamped = min(max(wanted, lower), upper)
        nframes = round_by_factor(clamped, FRAME_FACTOR)
    # Never ask for more frames than the video has.
    nframes = min(nframes, total_frames)
    if nframes < FRAME_FACTOR or nframes > total_frames:
        raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _read_video_torchvision(ele: dict) -> tuple[torch.Tensor, dict]:
    """Read and sample a video with ``torchvision.io.read_video``.

    Args:
        ele: Video request dict. ``ele["video"]`` is the path or URL; optional
            ``video_start`` / ``video_end`` (seconds) bound the decoded span.
            Frame-count keys are interpreted by ``smart_nframes``.

    Returns:
        ``(video, metadata)`` where ``video`` holds the sampled frames in
        (T, C, H, W) layout and ``metadata`` has ``total_num_frames``, ``fps``
        and ``frames_indices`` (indices of the sampled frames within the
        decoded span).
    """
    video_path = ele["video"]
    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
        if video_path.startswith(("http://", "https://")):
            warnings.warn("torchvision < 0.19.0 does not support http/https video path.")
        # Only strip a genuine leading scheme; a substring test would chop the
        # first 7 characters off any path that merely contains "file://".
        if video_path.startswith("file://"):
            video_path = video_path[len("file://"):]
    video, audio, info = io.read_video(
        video_path,
        start_pts=ele.get("video_start", 0.0),
        end_pts=ele.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    total_frames, video_fps = video.size(0), info["video_fps"]
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    # Evenly spaced frame indices covering the first through last decoded frame.
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    video = video[idx]
    metadata = {
        "total_num_frames": total_frames,
        "fps": video_fps,
        "frames_indices": idx,
    }
    return video, metadata
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def is_decord_available() -> bool:
    """Report whether the optional ``decord`` package can be located."""
    from importlib import util as importlib_util

    spec = importlib_util.find_spec("decord")
    return spec is not None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _read_video_decord(ele: dict) -> tuple[torch.Tensor, dict]:
    """Read and sample a video with ``decord.VideoReader``.

    Args:
        ele: Video request dict. ``ele["video"]`` is passed to
            ``decord.VideoReader``; frame-count keys are interpreted by
            ``smart_nframes``.

    Returns:
        ``(video, metadata)`` where ``video`` holds the sampled frames in
        (T, C, H, W) layout and ``metadata`` has ``total_num_frames``, ``fps``
        and ``frames_indices``.
    """
    import decord

    reader = decord.VideoReader(ele["video"])
    total_frames = len(reader)
    video_fps = reader.get_avg_fps()
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    # Evenly spaced frame indices from the first to the last frame.
    frame_ids = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    frames = reader.get_batch(frame_ids).asnumpy()
    # decord yields channels-last frames; move channels ahead of H and W.
    video = torch.tensor(frames).permute(0, 3, 1, 2)
    return video, {
        "total_num_frames": total_frames,
        "fps": video_fps,
        "frames_indices": frame_ids,
    }
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Maps a backend name to the function that decodes a video with it.
VIDEO_READER_BACKENDS = dict(
    decord=_read_video_decord,
    torchvision=_read_video_torchvision,
)

# Optional backend override, read from the environment once at import time.
FORCE_QWENVL_VIDEO_READER = os.environ.get("FORCE_QWENVL_VIDEO_READER")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    """Choose the video reader backend name, caching the first answer.

    The ``FORCE_QWENVL_VIDEO_READER`` override wins when set; otherwise
    ``decord`` is used if available, else ``torchvision``. The choice is
    announced on stderr (once, because the result is cached).
    """
    if FORCE_QWENVL_VIDEO_READER is not None:
        chosen = FORCE_QWENVL_VIDEO_READER
    else:
        chosen = "decord" if is_decord_available() else "torchvision"
    print(f"vefx-reward using {chosen} to read video.", file=sys.stderr)
    return chosen
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
    """Load one image described by ``ele`` and resize it for the model.

    The image is taken from ``ele["image"]`` (or ``ele["image_url"]``) and may
    be a ``PIL.Image``, an http(s) URL, a ``file://`` path, a
    ``data:image...base64,`` URI, or a plain local path.

    Args:
        ele: Image request dict. Optional ``resized_height`` + ``resized_width``
            give an explicit target size; otherwise optional ``min_pixels`` /
            ``max_pixels`` bound the target area.
        size_factor: Both sides of the output are multiples of this value.

    Returns:
        The RGB image resized to the size chosen by ``smart_resize``.

    Raises:
        ValueError: If the input cannot be interpreted as an image source.
        requests.HTTPError: If an http(s) URL answers with an error status.
    """
    image = ele["image"] if "image" in ele else ele["image_url"]
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith(("http://", "https://")):
        # Read the whole body inside a context manager so the connection is
        # released, HTTP errors surface as such, and content-encoding (e.g.
        # gzip) is decoded before PIL sees the bytes.
        with requests.get(image) as response:
            response.raise_for_status()
            image_obj = Image.open(BytesIO(response.content))
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = image_obj.convert("RGB")
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"], ele["resized_width"], factor=size_factor,
        )
    else:
        width, height = image.size
        min_pixels = ele.get("min_pixels", MIN_PIXELS)
        max_pixels = ele.get("max_pixels", MAX_PIXELS)
        resized_height, resized_width = smart_resize(
            height, width, factor=size_factor, min_pixels=min_pixels, max_pixels=max_pixels,
        )
    image = image.resize((resized_width, resized_height))
    return image
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> tuple[torch.Tensor | list[Image.Image], dict | None]:
    """Load a video described by ``ele`` and resize its frames.

    ``ele["video"]`` is either a string (decoded by the selected reader
    backend) or a list/tuple of frame images (each loaded via ``fetch_image``).

    Args:
        ele: Video request dict. Optional ``min_pixels``, ``total_pixels``,
            ``max_pixels`` and ``resized_height`` + ``resized_width`` steer
            the per-frame target size.
        image_factor: Both sides of every frame are multiples of this value.

    Returns:
        For a string source: ``(video, metadata)`` with ``video`` a float
        tensor of resized frames and ``metadata`` as produced by the reader.
        For a frame list: ``(images, None)`` where ``images`` is padded with
        copies of the last frame up to a multiple of ``FRAME_FACTOR``.
    """
    source = ele["video"]

    if not isinstance(source, str):
        assert isinstance(source, (list, tuple))
        # Everything except the type tag and the frame list is forwarded to
        # fetch_image as per-frame options.
        process_info = ele.copy()
        process_info.pop("type", None)
        process_info.pop("video", None)
        images = []
        for video_element in source:
            images.append(
                fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
            )
        nframes = ceil_by_factor(len(images), FRAME_FACTOR)
        missing = nframes - len(images)
        if missing > 0:
            images.extend([images[-1]] * missing)
        return images, None

    reader = VIDEO_READER_BACKENDS[get_video_reader_backend()]
    video, metadata = reader(ele)
    nframes, _, height, width = video.shape
    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
    # Per-frame cap: the clip-wide budget spread over the frames, limited by
    # VIDEO_MAX_PIXELS and kept a little above min_pixels.
    per_frame_cap = min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR)
    max_pixels = ele.get("max_pixels", max(per_frame_cap, int(min_pixels * 1.05)))
    if "resized_height" in ele and "resized_width" in ele:
        target_size = smart_resize(
            ele["resized_height"], ele["resized_width"], factor=image_factor,
        )
    else:
        target_size = smart_resize(
            height, width, factor=image_factor,
            min_pixels=min_pixels, max_pixels=max_pixels,
        )
    resized_height, resized_width = target_size
    resized = transforms.functional.resize(
        video, [resized_height, resized_width],
        interpolation=InterpolationMode.BICUBIC, antialias=True,
    )
    return resized.float(), metadata
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
    """Collect every image/video content part from one or more conversations.

    Args:
        conversations: Either a single conversation (a list of message dicts)
            or a list of conversations. Messages whose ``content`` is not a
            list are ignored.

    Returns:
        The content-part dicts, in order, that have an ``image``,
        ``image_url`` or ``video`` key, or whose ``type`` is one of those
        three names. An empty input yields an empty list.
    """
    if not conversations:
        return []
    if isinstance(conversations[0], dict):
        conversations = [conversations]
    vision_types = ("image", "image_url", "video")
    vision_infos = []
    for conversation in conversations:
        for message in conversation:
            content = message["content"]
            if not isinstance(content, list):
                continue
            for ele in content:
                # Parts without a "type" key (e.g. {"text": ...}) are simply
                # not vision parts; .get avoids a KeyError on them.
                if (
                    "image" in ele
                    or "image_url" in ele
                    or "video" in ele
                    or ele.get("type") in vision_types
                ):
                    vision_infos.append(ele)
    return vision_infos
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def process_vision_info(
    conversations: list[dict] | list[list[dict]],
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, list[dict] | None]:
    """Load every image and video referenced in the conversation messages.

    Args:
        conversations: A single conversation or a list of conversations, as
            accepted by ``extract_vision_info``.

    Returns:
        ``(image_inputs, video_inputs, video_metadata_list)``. ``image_inputs``
        is ``None`` when there are no images; ``video_inputs`` and
        ``video_metadata_list`` are both ``None`` when there are no videos.

    Raises:
        ValueError: If a collected part has none of the ``image``,
            ``image_url`` or ``video`` keys.
    """
    images = []
    videos = []
    videos_metadata = []
    for info in extract_vision_info(conversations):
        if "image" in info or "image_url" in info:
            images.append(fetch_image(info))
            continue
        if "video" not in info:
            raise ValueError("image, image_url or video should in content.")
        video, metadata = fetch_video(info)
        videos.append(video)
        videos_metadata.append(metadata)

    image_inputs = images if images else None
    if not videos:
        return image_inputs, None, None
    return image_inputs, videos, videos_metadata
|