update
Browse files- LICENSE +21 -0
- environment.yaml +223 -0
- evaluate/README.md +107 -0
- evaluate/benchmark.py +463 -0
- evaluate/evaluate.py +203 -0
- generate/README.md +104 -0
- generate/__init__.py +0 -0
- generate/gen_images_from_prompt.py +117 -0
- generate/generator.py +211 -0
- generate/utils/__init__.py +0 -0
- generate/utils/pipelines.py +282 -0
- generate/utils/utils.py +58 -0
- hpsv3/__init__.py +1 -0
- hpsv3/cohp/__init__.py +0 -0
- hpsv3/cohp/cohp_all.py +290 -0
- hpsv3/cohp/generator.py +64 -0
- hpsv3/cohp/run_cohp.py +244 -0
- hpsv3/cohp/utils_cohp/__init__.py +0 -0
- hpsv3/cohp/utils_cohp/image2image_pipeline.py +65 -0
- hpsv3/cohp/utils_cohp/pipelines.py +290 -0
- hpsv3/cohp/utils_cohp/utils.py +53 -0
- hpsv3/config/HPSv3_7B.yaml +60 -0
- hpsv3/config/ds_config/zero0.json +19 -0
- hpsv3/config/ds_config/zero2.json +23 -0
- hpsv3/config/ds_config/zero3.json +28 -0
- hpsv3/dataset/data_collator_qwen.py +205 -0
- hpsv3/dataset/pairwise_dataset.py +77 -0
- hpsv3/dataset/utils.py +426 -0
- hpsv3/inference.py +167 -0
- hpsv3/model/differentiable_image_processor.py +629 -0
- hpsv3/model/qwen2vl_trainer.py +971 -0
- hpsv3/model/test_differentiable.py +212 -0
- hpsv3/train.py +315 -0
- hpsv3/utils/parser.py +150 -0
- hpsv3/utils/training_utils.py +158 -0
- pretrained_models/download_pretrained_models.sh +63 -0
- pyproject.toml +64 -0
- requirements.txt +2 -2
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 HPSv3 Team
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
environment.yaml
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: hpsv3
|
| 2 |
+
channels:
|
| 3 |
+
- nvidia
|
| 4 |
+
- conda-forge
|
| 5 |
+
- defaults
|
| 6 |
+
dependencies:
|
| 7 |
+
- _libgcc_mutex=0.1=conda_forge
|
| 8 |
+
- _openmp_mutex=4.5=2_gnu
|
| 9 |
+
- bzip2=1.0.8=h4bc722e_7
|
| 10 |
+
- ca-certificates=2025.4.26=hbd8a1cb_0
|
| 11 |
+
- ld_impl_linux-64=2.43=h712a8e2_4
|
| 12 |
+
- libexpat=2.7.0=h5888daf_0
|
| 13 |
+
- libffi=3.4.6=h2dba641_1
|
| 14 |
+
- libgcc=15.1.0=h767d61c_2
|
| 15 |
+
- libgcc-ng=15.1.0=h69a702a_2
|
| 16 |
+
- libgomp=15.1.0=h767d61c_2
|
| 17 |
+
- liblzma=5.8.1=hb9d3cd8_1
|
| 18 |
+
- libnsl=2.0.1=hd590300_0
|
| 19 |
+
- libsqlite=3.50.0=hee588c1_0
|
| 20 |
+
- libuuid=2.38.1=h0b41bf4_0
|
| 21 |
+
- libxcrypt=4.4.36=hd590300_1
|
| 22 |
+
- libzlib=1.3.1=hb9d3cd8_2
|
| 23 |
+
- ncurses=6.5=h2d0b736_3
|
| 24 |
+
- openssl=3.5.0=h7b32b05_1
|
| 25 |
+
- pip=25.1.1=pyh8b19718_0
|
| 26 |
+
- python=3.10.17=hd6af730_0_cpython
|
| 27 |
+
- readline=8.2=h8c095d6_2
|
| 28 |
+
- setuptools=80.8.0=pyhff2d567_0
|
| 29 |
+
- tk=8.6.13=noxft_hd72426e_102
|
| 30 |
+
- wheel=0.45.1=pyhd8ed1ab_1
|
| 31 |
+
- pip:
|
| 32 |
+
- absl-py==2.3.0
|
| 33 |
+
- accelerate==1.8.0
|
| 34 |
+
- aiohappyeyeballs==2.6.1
|
| 35 |
+
- aiohttp==3.12.12
|
| 36 |
+
- aiosignal==1.3.2
|
| 37 |
+
- annotated-types==0.7.0
|
| 38 |
+
- antlr4-python3-runtime==4.9.3
|
| 39 |
+
- anyio==4.9.0
|
| 40 |
+
- argon2-cffi==23.1.0
|
| 41 |
+
- argon2-cffi-bindings==21.2.0
|
| 42 |
+
- arrow==1.3.0
|
| 43 |
+
- asttokens==3.0.0
|
| 44 |
+
- async-lru==2.0.5
|
| 45 |
+
- async-timeout==5.0.1
|
| 46 |
+
- attrs==25.3.0
|
| 47 |
+
- av==14.4.0
|
| 48 |
+
- babel==2.17.0
|
| 49 |
+
- beautifulsoup4==4.13.4
|
| 50 |
+
- bleach==6.2.0
|
| 51 |
+
- botocore==1.38.35
|
| 52 |
+
- certifi==2025.4.26
|
| 53 |
+
- cffi==1.17.1
|
| 54 |
+
- charset-normalizer==3.4.2
|
| 55 |
+
- comm==0.2.2
|
| 56 |
+
- contourpy==1.3.2
|
| 57 |
+
- cycler==0.12.1
|
| 58 |
+
- datasets==3.6.0
|
| 59 |
+
- debugpy==1.8.14
|
| 60 |
+
- decorator==5.2.1
|
| 61 |
+
- deepspeed==0.15.4
|
| 62 |
+
- defusedxml==0.7.1
|
| 63 |
+
- diffusers==0.33.1
|
| 64 |
+
- dill==0.3.8
|
| 65 |
+
- docstring-parser==0.16
|
| 66 |
+
- einops==0.8.1
|
| 67 |
+
- exceptiongroup==1.3.0
|
| 68 |
+
- executing==2.2.0
|
| 69 |
+
- fastjsonschema==2.21.1
|
| 70 |
+
- filelock==3.13.1
|
| 71 |
+
- fire==0.7.0
|
| 72 |
+
- fonttools==4.58.1
|
| 73 |
+
- fqdn==1.5.1
|
| 74 |
+
- frozenlist==1.7.0
|
| 75 |
+
- fsspec==2024.6.1
|
| 76 |
+
- grpcio==1.72.1
|
| 77 |
+
- h11==0.16.0
|
| 78 |
+
- hf-xet==1.1.3
|
| 79 |
+
- hjson==3.1.0
|
| 80 |
+
- httpcore==1.0.9
|
| 81 |
+
- httpx==0.28.1
|
| 82 |
+
- huggingface-hub==0.32.4
|
| 83 |
+
- idna==3.10
|
| 84 |
+
- imageio==2.37.0
|
| 85 |
+
- importlib-metadata==8.7.0
|
| 86 |
+
- ipykernel==6.29.5
|
| 87 |
+
- ipython==8.36.0
|
| 88 |
+
- ipywidgets==8.1.7
|
| 89 |
+
- isoduration==20.11.0
|
| 90 |
+
- jedi==0.19.2
|
| 91 |
+
- jinja2==3.1.6
|
| 92 |
+
- jmespath==1.0.1
|
| 93 |
+
- json5==0.12.0
|
| 94 |
+
- jsonpointer==3.0.0
|
| 95 |
+
- jsonschema==4.24.0
|
| 96 |
+
- jsonschema-specifications==2025.4.1
|
| 97 |
+
- jupyter==1.1.1
|
| 98 |
+
- jupyter-client==8.6.3
|
| 99 |
+
- jupyter-console==6.6.3
|
| 100 |
+
- jupyter-core==5.8.1
|
| 101 |
+
- jupyter-events==0.12.0
|
| 102 |
+
- jupyter-lsp==2.2.5
|
| 103 |
+
- jupyter-server==2.16.0
|
| 104 |
+
- jupyter-server-terminals==0.5.3
|
| 105 |
+
- jupyterlab==4.4.3
|
| 106 |
+
- jupyterlab-pygments==0.3.0
|
| 107 |
+
- jupyterlab-server==2.27.3
|
| 108 |
+
- jupyterlab-widgets==3.0.15
|
| 109 |
+
- kiwisolver==1.4.8
|
| 110 |
+
- markdown==3.8
|
| 111 |
+
- markdown-it-py==3.0.0
|
| 112 |
+
- markupsafe==3.0.2
|
| 113 |
+
- matplotlib==3.10.3
|
| 114 |
+
- matplotlib-inline==0.1.7
|
| 115 |
+
- mdurl==0.1.2
|
| 116 |
+
- mistune==3.1.3
|
| 117 |
+
- mpmath==1.3.0
|
| 118 |
+
- msgpack==1.1.0
|
| 119 |
+
- multidict==6.4.4
|
| 120 |
+
- multiprocess==0.70.16
|
| 121 |
+
- nbclient==0.10.2
|
| 122 |
+
- nbconvert==7.16.6
|
| 123 |
+
- nbformat==5.10.4
|
| 124 |
+
- nest-asyncio==1.6.0
|
| 125 |
+
- networkx==3.3
|
| 126 |
+
- ninja==1.11.1.4
|
| 127 |
+
- notebook==7.4.3
|
| 128 |
+
- notebook-shim==0.2.4
|
| 129 |
+
- numpy==2.1.2
|
| 130 |
+
- nvidia-cublas-cu11==11.11.3.6
|
| 131 |
+
- nvidia-cuda-cupti-cu11==11.8.87
|
| 132 |
+
- nvidia-cuda-nvrtc-cu11==11.8.89
|
| 133 |
+
- nvidia-cuda-runtime-cu11==11.8.89
|
| 134 |
+
- nvidia-cudnn-cu11==9.1.0.70
|
| 135 |
+
- nvidia-cufft-cu11==10.9.0.58
|
| 136 |
+
- nvidia-curand-cu11==10.3.0.86
|
| 137 |
+
- nvidia-cusolver-cu11==11.4.1.48
|
| 138 |
+
- nvidia-cusparse-cu11==11.7.5.86
|
| 139 |
+
- nvidia-ml-py==12.575.51
|
| 140 |
+
- nvidia-nccl-cu11==2.21.5
|
| 141 |
+
- nvidia-nvtx-cu11==11.8.86
|
| 142 |
+
- omegaconf==2.3.0
|
| 143 |
+
- opencv-python==4.11.0.86
|
| 144 |
+
- overrides==7.7.0
|
| 145 |
+
- packaging==25.0
|
| 146 |
+
- pandas==2.3.0
|
| 147 |
+
- pandocfilters==1.5.1
|
| 148 |
+
- parso==0.8.4
|
| 149 |
+
- peft==0.10.0
|
| 150 |
+
- pexpect==4.9.0
|
| 151 |
+
- pillow==11.0.0
|
| 152 |
+
- platformdirs==4.3.8
|
| 153 |
+
- prometheus-client==0.22.0
|
| 154 |
+
- prompt-toolkit==3.0.51
|
| 155 |
+
- propcache==0.3.2
|
| 156 |
+
- protobuf==6.31.1
|
| 157 |
+
- psutil==7.0.0
|
| 158 |
+
- ptyprocess==0.7.0
|
| 159 |
+
- pure-eval==0.2.3
|
| 160 |
+
- py-cpuinfo==9.0.0
|
| 161 |
+
- pyarrow==20.0.0
|
| 162 |
+
- pycparser==2.22
|
| 163 |
+
- pydantic==2.11.5
|
| 164 |
+
- pydantic-core==2.33.2
|
| 165 |
+
- pygments==2.19.1
|
| 166 |
+
- pyparsing==3.2.3
|
| 167 |
+
- python-dateutil==2.9.0.post0
|
| 168 |
+
- python-json-logger==3.3.0
|
| 169 |
+
- pytz==2025.2
|
| 170 |
+
- pyyaml==6.0.2
|
| 171 |
+
- pyzmq==26.4.0
|
| 172 |
+
- prettytable==3.8.0
|
| 173 |
+
- qwen-vl-utils==0.0.11
|
| 174 |
+
- referencing==0.36.2
|
| 175 |
+
- regex==2024.11.6
|
| 176 |
+
- requests==2.32.3
|
| 177 |
+
- rfc3339-validator==0.1.4
|
| 178 |
+
- rfc3986-validator==0.1.1
|
| 179 |
+
- rich==14.0.0
|
| 180 |
+
- rpds-py==0.25.1
|
| 181 |
+
- safetensors==0.5.3
|
| 182 |
+
- send2trash==1.8.3
|
| 183 |
+
- sentencepiece==0.2.0
|
| 184 |
+
- shtab==1.7.2
|
| 185 |
+
- six==1.17.0
|
| 186 |
+
- sniffio==1.3.1
|
| 187 |
+
- soupsieve==2.7
|
| 188 |
+
- stack-data==0.6.3
|
| 189 |
+
- sympy==1.13.1
|
| 190 |
+
- tensorboard==2.19.0
|
| 191 |
+
- tensorboard-data-server==0.7.2
|
| 192 |
+
- termcolor==3.1.0
|
| 193 |
+
- terminado==0.18.1
|
| 194 |
+
- timm==1.0.15
|
| 195 |
+
- tinycss2==1.4.0
|
| 196 |
+
- tokenizers==0.20.3
|
| 197 |
+
- tomli==2.2.1
|
| 198 |
+
- torch==2.6.0
|
| 199 |
+
- torchaudio==2.6.0
|
| 200 |
+
- torchvision==0.21.0
|
| 201 |
+
- tornado==6.5.1
|
| 202 |
+
- tqdm==4.67.1
|
| 203 |
+
- traitlets==5.14.3
|
| 204 |
+
- transformers==4.45.2
|
| 205 |
+
- triton==3.2.0
|
| 206 |
+
- trl==0.8.6
|
| 207 |
+
- typeguard==4.4.3
|
| 208 |
+
- types-python-dateutil==2.9.0.20250516
|
| 209 |
+
- typing-extensions==4.14.0
|
| 210 |
+
- typing-inspection==0.4.1
|
| 211 |
+
- tyro==0.9.24
|
| 212 |
+
- tzdata==2025.2
|
| 213 |
+
- uri-template==1.3.0
|
| 214 |
+
- urllib3==2.4.0
|
| 215 |
+
- wcwidth==0.2.13
|
| 216 |
+
- webcolors==24.11.1
|
| 217 |
+
- webencodings==0.5.1
|
| 218 |
+
- websocket-client==1.8.0
|
| 219 |
+
- werkzeug==3.1.3
|
| 220 |
+
- widgetsnbextension==4.0.14
|
| 221 |
+
- xxhash==3.5.0
|
| 222 |
+
- yarl==1.20.1
|
| 223 |
+
- zipp==3.22.0
|
evaluate/README.md
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Model Performance Evaluation (`evaluate.py`)
|
| 2 |
+
|
| 3 |
+
This script is used to evaluate the model's performance on a test set. It can operate in two modes:
|
| 4 |
+
|
| 5 |
+
- **`pair`**: Calculates pairwise accuracy.
|
| 6 |
+
- **`ranking`**: Calculates ranking accuracy.
|
| 7 |
+
|
| 8 |
+
**Pair-wise Sample**
|
| 9 |
+
|
| 10 |
+
We set path1's image is better than path2's image for simplicity.
|
| 11 |
+
|
| 12 |
+
```json
|
| 13 |
+
[
|
| 14 |
+
{
|
| 15 |
+
"prompt": ".....",
|
| 16 |
+
"path1": ".....",
|
| 17 |
+
"path2": "....."
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"prompt": ".....",
|
| 21 |
+
"path1": ".....",
|
| 22 |
+
"path2": "....."
|
| 23 |
+
},
|
| 24 |
+
...
|
| 25 |
+
]
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
**Rank-wise Sample**
|
| 29 |
+
|
| 30 |
+
```json
|
| 31 |
+
[
|
| 32 |
+
{
|
| 33 |
+
"id": "005658-0040",
|
| 34 |
+
"prompt": ".....",
|
| 35 |
+
"generations": [
|
| 36 |
+
"path to image1",
|
| 37 |
+
"path to image2",
|
| 38 |
+
"path to image3",
|
| 39 |
+
"path to image4"
|
| 40 |
+
],
|
| 41 |
+
"ranking": [
|
| 42 |
+
1,
|
| 43 |
+
2,
|
| 44 |
+
5,
|
| 45 |
+
3
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
...
|
| 49 |
+
]
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### Usage
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
python evaluate/evaluate.py \
|
| 56 |
+
--test_json /path/to/your/test_data.json \
|
| 57 |
+
--config_path config/HPSv3_7B.yaml \
|
| 58 |
+
--checkpoint_path checkpoints/HPSv3_7B/model.pth \
|
| 59 |
+
--mode pair \
|
| 60 |
+
--batch_size 8 \
|
| 61 |
+
--num_processes 8
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
**Arguments:**
|
| 65 |
+
|
| 66 |
+
- `--test_json`: (Required) Path to the JSON file containing evaluation data.
|
| 67 |
+
- `--config_path`: (Required) Path to the model's configuration file.
|
| 68 |
+
- `--checkpoint_path`: (Required) Path to the model checkpoint.
|
| 69 |
+
- `--mode`: The evaluation mode. Can be `pair` or `ranking`. (Default: `pair`)
|
| 70 |
+
- `--batch_size`: Batch size for inference. (Default: 8)
|
| 71 |
+
- `--num_processes`: Number of parallel processes to use. (Default: 8)
|
| 72 |
+
|
| 73 |
+
---
|
| 74 |
+
|
| 75 |
+
## Reward Benchmarking (`benchmark.py`)
|
| 76 |
+
|
| 77 |
+
This script is used to run inference with a reward model over one or more folders of images. It calculates a reward score for each image based on its corresponding text prompt (expected in a `.txt` file with the same name). The script then outputs statistics (mean, std, min, max) for each folder and saves the detailed results to a JSON file.
|
| 78 |
+
|
| 79 |
+
It supports multiple reward models through the `--model_type` argument.
|
| 80 |
+
|
| 81 |
+
### Usage
|
| 82 |
+
|
| 83 |
+
The script is run using `argparse`. Below is a command-line example:
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
python evaluate/benchmark.py \
|
| 87 |
+
--config_path config/HPSv3_7B.yaml \
|
| 88 |
+
--checkpoint_path checkpoints/HPSv3_7B/model.pth \
|
| 89 |
+
--model_type hpsv3 \
|
| 90 |
+
--image_folders /path/to/images/folder1 /path/to/images/folder2 \
|
| 91 |
+
--output_path ./benchmark_results.json \
|
| 92 |
+
--batch_size 16 \
|
| 93 |
+
--num_processes 8
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
**Arguments:**
|
| 97 |
+
|
| 98 |
+
- `--config_path`: (Required) Path to the model's configuration file.
|
| 99 |
+
- `--checkpoint_path`: (Required) Path to the model checkpoint.
|
| 100 |
+
- `--model_type`: The reward model to use. Choices: `hpsv3`, `hpsv2`, `imagereward`. (Default: `hpsv3`)
|
| 101 |
+
- `--image_folders`: (Required) One or more paths to folders containing the images to benchmark.
|
| 102 |
+
- `--output_path`: (Required) Path to save the output JSON file with results.
|
| 103 |
+
- `--batch_size`: Batch size for processing. (Default: 16)
|
| 104 |
+
- `--num_processes`: Number of parallel processes to use. (Default: 8)
|
| 105 |
+
- `--num_machines`: For distributed inference, the total number of machines. (Default: 1)
|
| 106 |
+
- `--machine_id`: For distributed inference, the ID of the current machine. (Default: 0)
|
| 107 |
+
|
evaluate/benchmark.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import multiprocessing as mp
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from hpsv3.inference import HPSv3RewardInferencer
|
| 7 |
+
import argparse
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import glob
|
| 10 |
+
import numpy as np
|
| 11 |
+
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import ImageReward as RM
|
| 14 |
+
from transformers import AutoProcessor, AutoModel
|
| 15 |
+
def initialize_model_hpsv2(device, cp):
|
| 16 |
+
model_dict = {}
|
| 17 |
+
model, preprocess_train, preprocess_val = create_model_and_transforms(
|
| 18 |
+
'ViT-H-14',
|
| 19 |
+
'laion2B-s32B-b79K',
|
| 20 |
+
precision='amp',
|
| 21 |
+
device=device,
|
| 22 |
+
jit=False,
|
| 23 |
+
force_quick_gelu=False,
|
| 24 |
+
force_custom_text=False,
|
| 25 |
+
force_patch_dropout=False,
|
| 26 |
+
force_image_size=None,
|
| 27 |
+
pretrained_image=False,
|
| 28 |
+
image_mean=None,
|
| 29 |
+
image_std=None,
|
| 30 |
+
light_augmentation=True,
|
| 31 |
+
aug_cfg={},
|
| 32 |
+
output_dict=True,
|
| 33 |
+
with_score_predictor=False,
|
| 34 |
+
with_region_predictor=False
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
checkpoint = torch.load(cp, map_location=device, weights_only=False)
|
| 38 |
+
model.load_state_dict(checkpoint['state_dict'])
|
| 39 |
+
model = model.to(device)
|
| 40 |
+
model.eval()
|
| 41 |
+
tokenizer = get_tokenizer('ViT-H-14')
|
| 42 |
+
|
| 43 |
+
model_dict['model'] = model
|
| 44 |
+
model_dict['preprocess_val'] = preprocess_val
|
| 45 |
+
return model_dict, tokenizer
|
| 46 |
+
|
| 47 |
+
def initialize_pickscore(device, checkpoint_path):
|
| 48 |
+
processor = AutoProcessor.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
|
| 49 |
+
model = AutoModel.from_pretrained(checkpoint_path).eval().to(device)
|
| 50 |
+
return model, processor
|
| 51 |
+
|
| 52 |
+
def initialize_aesthetic_model():
|
| 53 |
+
import open_clip
|
| 54 |
+
from os.path import expanduser
|
| 55 |
+
from urllib.request import urlretrieve
|
| 56 |
+
import torch.nn as nn
|
| 57 |
+
|
| 58 |
+
def get_aesthetic_model(clip_model="vit_l_14"):
|
| 59 |
+
"""Load the aesthetic model with caching"""
|
| 60 |
+
|
| 61 |
+
home = expanduser("~")
|
| 62 |
+
cache_folder = home + "/.cache/emb_reader"
|
| 63 |
+
path_to_model = cache_folder + "/sa_0_4_"+clip_model+"_linear.pth"
|
| 64 |
+
if not os.path.exists(path_to_model):
|
| 65 |
+
os.makedirs(cache_folder, exist_ok=True)
|
| 66 |
+
url_model = (
|
| 67 |
+
"https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_"+clip_model+"_linear.pth?raw=true"
|
| 68 |
+
)
|
| 69 |
+
urlretrieve(url_model, path_to_model)
|
| 70 |
+
# Create appropriate linear layer
|
| 71 |
+
if clip_model == "vit_l_14":
|
| 72 |
+
m = nn.Linear(768, 1)
|
| 73 |
+
elif clip_model == "vit_b_32":
|
| 74 |
+
m = nn.Linear(512, 1)
|
| 75 |
+
else:
|
| 76 |
+
raise ValueError()
|
| 77 |
+
m.load_state_dict(torch.load(path_to_model))
|
| 78 |
+
m.eval()
|
| 79 |
+
return m
|
| 80 |
+
|
| 81 |
+
model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
|
| 82 |
+
amodel = get_aesthetic_model(clip_model="vit_l_14")
|
| 83 |
+
return model, preprocess, amodel
|
| 84 |
+
|
| 85 |
+
def initialize_clip(device):
|
| 86 |
+
"""Initialize the CLIP model and processor."""
|
| 87 |
+
model = AutoModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
| 88 |
+
processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
| 89 |
+
return model.to(device), processor
|
| 90 |
+
|
| 91 |
+
def score_hpsv2_batch(model_dict, tokenizer, device, img_paths: list, prompts: list) -> list:
|
| 92 |
+
model = model_dict['model']
|
| 93 |
+
preprocess_val = model_dict['preprocess_val']
|
| 94 |
+
|
| 95 |
+
# 批量处理图片
|
| 96 |
+
images = [preprocess_val(Image.open(p)).unsqueeze(0)[:,:3,:,:] for p in img_paths]
|
| 97 |
+
images = torch.cat(images, dim=0).to(device=device)
|
| 98 |
+
texts = tokenizer(prompts).to(device=device)
|
| 99 |
+
with torch.no_grad():
|
| 100 |
+
outputs = model(images, texts)
|
| 101 |
+
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
| 102 |
+
logits_per_image = image_features @ text_features.T
|
| 103 |
+
hps_scores = torch.diagonal(logits_per_image).cpu()
|
| 104 |
+
return hps_scores
|
| 105 |
+
|
| 106 |
+
def score_pick_score_batch(prompts, images, model, processor, device):
|
| 107 |
+
# preprocess
|
| 108 |
+
pil_images = [Image.open(p) for p in images]
|
| 109 |
+
image_inputs = processor(
|
| 110 |
+
images=pil_images,
|
| 111 |
+
padding=True,
|
| 112 |
+
truncation=True,
|
| 113 |
+
max_length=77,
|
| 114 |
+
return_tensors="pt",
|
| 115 |
+
).to(device)
|
| 116 |
+
|
| 117 |
+
text_inputs = processor(
|
| 118 |
+
text=prompts,
|
| 119 |
+
padding=True,
|
| 120 |
+
truncation=True,
|
| 121 |
+
max_length=77,
|
| 122 |
+
return_tensors="pt",
|
| 123 |
+
).to(device)
|
| 124 |
+
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
# embed
|
| 127 |
+
image_embs = model.get_image_features(**image_inputs)
|
| 128 |
+
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
| 129 |
+
|
| 130 |
+
text_embs = model.get_text_features(**text_inputs)
|
| 131 |
+
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
| 132 |
+
# score
|
| 133 |
+
scores = model.logit_scale.exp() * (text_embs @ image_embs.T)
|
| 134 |
+
scores = torch.diagonal(scores).cpu()
|
| 135 |
+
|
| 136 |
+
return scores
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def score_aesthetic_batch(model, preprocess, aesthetic_model, device, img_paths: list) -> list:
|
| 140 |
+
"""Scores a batch of images using the aesthetic model."""
|
| 141 |
+
images = [preprocess(Image.open(p)).unsqueeze(0) for p in img_paths]
|
| 142 |
+
images = torch.cat(images, dim=0).to(device=device)
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
feat = model.encode_image(images)
|
| 145 |
+
feat = feat / feat.norm(dim=-1, keepdim=True)
|
| 146 |
+
pred = aesthetic_model(feat).cpu()
|
| 147 |
+
return pred
|
| 148 |
+
|
| 149 |
+
def score_clip_batch(model, processor, device, img_paths: list, prompts: list) -> list:
|
| 150 |
+
"""Scores a batch of images against prompts using CLIP."""
|
| 151 |
+
# preprocess
|
| 152 |
+
pil_images = [Image.open(p) for p in img_paths]
|
| 153 |
+
image_inputs = processor(
|
| 154 |
+
images=pil_images,
|
| 155 |
+
padding=True,
|
| 156 |
+
truncation=True,
|
| 157 |
+
max_length=77,
|
| 158 |
+
return_tensors="pt",
|
| 159 |
+
).to(device)
|
| 160 |
+
|
| 161 |
+
text_inputs = processor(
|
| 162 |
+
text=prompts,
|
| 163 |
+
padding=True,
|
| 164 |
+
truncation=True,
|
| 165 |
+
max_length=77,
|
| 166 |
+
return_tensors="pt",
|
| 167 |
+
).to(device)
|
| 168 |
+
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
# embed
|
| 171 |
+
image_embs = model.get_image_features(**image_inputs)
|
| 172 |
+
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
| 173 |
+
|
| 174 |
+
text_embs = model.get_text_features(**text_inputs)
|
| 175 |
+
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
| 176 |
+
# score
|
| 177 |
+
scores = image_embs @ text_embs.T
|
| 178 |
+
scores = torch.diagonal(scores).cpu()
|
| 179 |
+
|
| 180 |
+
return scores
|
| 181 |
+
|
| 182 |
+
def calculate_category_stats(data_dict):
|
| 183 |
+
"""Calculate statistics for each category"""
|
| 184 |
+
stats = {}
|
| 185 |
+
for category, data_list in data_dict.items():
|
| 186 |
+
if not data_list:
|
| 187 |
+
stats[category] = {
|
| 188 |
+
'count': 0,
|
| 189 |
+
'mean': 0.0,
|
| 190 |
+
'std': 0.0,
|
| 191 |
+
'min': 0.0,
|
| 192 |
+
'max': 0.0
|
| 193 |
+
}
|
| 194 |
+
continue
|
| 195 |
+
|
| 196 |
+
rewards = [item['reward'] for item in data_list]
|
| 197 |
+
stats[category] = {
|
| 198 |
+
'count': len(rewards),
|
| 199 |
+
'mean': float(np.mean(rewards)),
|
| 200 |
+
'std': float(np.std(rewards)),
|
| 201 |
+
'min': float(np.min(rewards)),
|
| 202 |
+
'max': float(np.max(rewards))
|
| 203 |
+
}
|
| 204 |
+
total_mean = np.mean([stat['mean'] for stat in stats.values() if stat['count'] > 0])
|
| 205 |
+
stats['OVERALL'] = {
|
| 206 |
+
'count': sum(stat['count'] for stat in stats.values()),
|
| 207 |
+
'mean': float(total_mean),
|
| 208 |
+
'std': float(np.std([stat['mean'] for stat in stats.values() if stat['count'] > 0])),
|
| 209 |
+
'min': float(min(stat['min'] for stat in stats.values() if stat['count'] > 0)),
|
| 210 |
+
'max': float(max(stat['max'] for stat in stats.values() if stat['count'] > 0))
|
| 211 |
+
}
|
| 212 |
+
return stats
|
| 213 |
+
|
| 214 |
+
def print_stats(stats):
|
| 215 |
+
print(f"{'Category':<30} {'Count':<8} {'Mean':<10} {'Std':<10} {'Min':<10} {'Max':<10}")
|
| 216 |
+
print("-" * 78)
|
| 217 |
+
for category, stat in stats.items():
|
| 218 |
+
category_name = category # Get folder name only
|
| 219 |
+
print(f"{category_name:<30} {stat['count']:<8} {stat['mean']:<10.4f} {stat['std']:<10.4f} {stat['min']:<10.4f} {stat['max']:<10.4f}")
|
| 220 |
+
|
| 221 |
+
# Calculate overall statistics
|
| 222 |
+
if stats:
|
| 223 |
+
all_counts = [stat['count'] for stat in stats.values()]
|
| 224 |
+
all_means = [stat['mean'] for stat in stats.values() if stat['count'] > 0]
|
| 225 |
+
if all_means:
|
| 226 |
+
print("-" * 78)
|
| 227 |
+
print(f"{'OVERALL':<30} {sum(all_counts):<8} {np.mean(all_means):<10.4f} {'':<10} {min([stat['min'] for stat in stats.values() if stat['count'] > 0]):<10.4f} {max([stat['max'] for stat in stats.values() if stat['count'] > 0]):<10.4f}")
|
| 228 |
+
|
| 229 |
+
def worker_process(process_id, process_dict, config_path, checkpoint_path, mode, device_id, dtype, batch_size, return_dict):
|
| 230 |
+
"""Worker process function that processes a chunk of data"""
|
| 231 |
+
category_rewards = defaultdict(list)
|
| 232 |
+
|
| 233 |
+
device = f"cuda:{device_id}" if torch.cuda.is_available() else "cpu"
|
| 234 |
+
if mode == 'imagereward':
|
| 235 |
+
model = RM.load("ImageReward-v1.0")
|
| 236 |
+
elif mode == 'hpsv2':
|
| 237 |
+
inferencer = initialize_model_hpsv2(device, checkpoint_path)
|
| 238 |
+
model_dict, tokenizer = inferencer
|
| 239 |
+
elif mode == 'hpsv3':
|
| 240 |
+
inferencer = HPSv3RewardInferencer(config_path=config_path, checkpoint_path=checkpoint_path,device=device)
|
| 241 |
+
elif mode == 'pickscore':
|
| 242 |
+
model, processor = initialize_pickscore(device, checkpoint_path)
|
| 243 |
+
elif mode == 'aesthetic':
|
| 244 |
+
model, preprocess, aesthetic_model = initialize_aesthetic_model()
|
| 245 |
+
model = model.to(device)
|
| 246 |
+
aesthetic_model = aesthetic_model.to(device)
|
| 247 |
+
elif mode == 'clip':
|
| 248 |
+
model, processor = initialize_clip(device)
|
| 249 |
+
model = model.to(device)
|
| 250 |
+
else:
|
| 251 |
+
raise ValueError(f"Unsupported mode: {mode}")
|
| 252 |
+
|
| 253 |
+
for category, chunk_data in tqdm(process_dict.items(), total=len(process_dict), desc='Total', disable=not process_id == 0):
|
| 254 |
+
processed_data = []
|
| 255 |
+
# Process data in batches
|
| 256 |
+
for batch_start in tqdm(range(0, len(chunk_data), batch_size),
|
| 257 |
+
total=(len(chunk_data) + batch_size - 1) // batch_size,
|
| 258 |
+
desc=f"Category {category}", disable=not process_id == 0):
|
| 259 |
+
batch_end = min(batch_start + batch_size, len(chunk_data))
|
| 260 |
+
image_paths = chunk_data[batch_start:batch_end]
|
| 261 |
+
text_paths = [p[:-4]+'.txt' for p in image_paths]
|
| 262 |
+
|
| 263 |
+
prompts = ['\n'.join(open(p, 'r').readlines()) for p in text_paths]
|
| 264 |
+
|
| 265 |
+
with torch.no_grad():
|
| 266 |
+
if mode == 'imagereward':
|
| 267 |
+
rewards = torch.tensor([model.score(prompt, image_path) for prompt, image_path in zip(prompts, image_paths)])
|
| 268 |
+
elif mode == 'hpsv2':
|
| 269 |
+
rewards = score_hpsv2_batch(model_dict, tokenizer, device, image_paths, prompts)
|
| 270 |
+
elif mode == 'hpsv3':
|
| 271 |
+
rewards = inferencer.reward(image_paths, prompts)
|
| 272 |
+
elif mode == 'pickscore':
|
| 273 |
+
rewards = score_pick_score_batch(prompts, image_paths, model, processor, device)
|
| 274 |
+
elif mode == 'aesthetic':
|
| 275 |
+
rewards = score_aesthetic_batch(model, preprocess, aesthetic_model, device, image_paths)
|
| 276 |
+
elif mode == 'clip':
|
| 277 |
+
rewards = score_clip_batch(model, processor, device, image_paths, prompts)
|
| 278 |
+
else:
|
| 279 |
+
raise ValueError(f"Unsupported mode: {mode}")
|
| 280 |
+
|
| 281 |
+
torch.cuda.empty_cache()
|
| 282 |
+
for i in range(len(image_paths)):
|
| 283 |
+
if rewards.ndim == 2:
|
| 284 |
+
reward = rewards[i][0].item()
|
| 285 |
+
else:
|
| 286 |
+
reward = rewards[i].item()
|
| 287 |
+
processed_data.append({
|
| 288 |
+
'image_path': image_paths[i],
|
| 289 |
+
'reward': reward,
|
| 290 |
+
'prompt': prompts[i]
|
| 291 |
+
})
|
| 292 |
+
|
| 293 |
+
category_rewards[category] = processed_data
|
| 294 |
+
|
| 295 |
+
return_dict[process_id] = {
|
| 296 |
+
'data': category_rewards,
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
def chunk_list(data_list, num_chunks):
|
| 300 |
+
"""Split list into roughly equal chunks"""
|
| 301 |
+
chunk_size = len(data_list) // num_chunks
|
| 302 |
+
remainder = len(data_list) % num_chunks
|
| 303 |
+
|
| 304 |
+
chunks = []
|
| 305 |
+
start = 0
|
| 306 |
+
for i in range(num_chunks):
|
| 307 |
+
# Add one extra item to first 'remainder' chunks
|
| 308 |
+
current_chunk_size = chunk_size + (1 if i < remainder else 0)
|
| 309 |
+
end = start + current_chunk_size
|
| 310 |
+
chunks.append(data_list[start:end])
|
| 311 |
+
start = end
|
| 312 |
+
|
| 313 |
+
return chunks
|
| 314 |
+
|
| 315 |
+
def main(config_path, checkpint_path, mode, image_folders, output_path, batch_size=16, num_processes=8, num_machines=1, machine_id=0):
|
| 316 |
+
print(f"Config path: {config_path}")
|
| 317 |
+
|
| 318 |
+
dtype = torch.bfloat16
|
| 319 |
+
|
| 320 |
+
# Gather all data first
|
| 321 |
+
folder_dict = {}
|
| 322 |
+
for folder in image_folders:
|
| 323 |
+
images = []
|
| 324 |
+
for ext in ['.png', '.jpg']:
|
| 325 |
+
images.extend(glob.glob(os.path.join(folder, "**", f"*{ext}"), recursive=True))
|
| 326 |
+
machine_image_chunks = chunk_list(images, num_machines)
|
| 327 |
+
image_list = machine_image_chunks[machine_id] if machine_id < len(machine_image_chunks) else []
|
| 328 |
+
print(f"Folder {folder} total data points: {len(image_list)}")
|
| 329 |
+
data_chunks = chunk_list(image_list, num_processes)
|
| 330 |
+
print(f"Folder {folder} data split into {num_processes} chunks with sizes: {[len(chunk) for chunk in data_chunks]}")
|
| 331 |
+
folder_dict[folder] = data_chunks
|
| 332 |
+
|
| 333 |
+
per_process_folder_dict = []
|
| 334 |
+
for i in range(num_processes):
|
| 335 |
+
one_dict = {}
|
| 336 |
+
for key, value in folder_dict.items():
|
| 337 |
+
one_dict[key] = value[i] if i < len(value) else []
|
| 338 |
+
per_process_folder_dict.append(one_dict)
|
| 339 |
+
|
| 340 |
+
# Create manager for shared data between processes
|
| 341 |
+
with mp.Manager() as manager:
|
| 342 |
+
return_dict = manager.dict()
|
| 343 |
+
processes = []
|
| 344 |
+
|
| 345 |
+
# Start processes
|
| 346 |
+
for i in range(num_processes):
|
| 347 |
+
device_id = i % torch.cuda.device_count() if torch.cuda.is_available() else 0
|
| 348 |
+
|
| 349 |
+
p = mp.Process(target=worker_process,
|
| 350 |
+
args=(i, per_process_folder_dict[i], config_path, checkpint_path, mode, device_id, dtype, batch_size, return_dict))
|
| 351 |
+
p.start()
|
| 352 |
+
processes.append(p)
|
| 353 |
+
|
| 354 |
+
for p in processes:
|
| 355 |
+
p.join()
|
| 356 |
+
|
| 357 |
+
# Collect results from all processes
|
| 358 |
+
all_processed_data = {}
|
| 359 |
+
for i in range(num_processes):
|
| 360 |
+
if i in return_dict:
|
| 361 |
+
result = return_dict[i]
|
| 362 |
+
process_data = result['data']
|
| 363 |
+
# Merge data from each process
|
| 364 |
+
for category, data_list in process_data.items():
|
| 365 |
+
if category not in all_processed_data:
|
| 366 |
+
all_processed_data[category] = []
|
| 367 |
+
all_processed_data[category].extend(data_list)
|
| 368 |
+
else:
|
| 369 |
+
print(f"No result from process {i}")
|
| 370 |
+
|
| 371 |
+
# Calculate and print statistics for current machine
|
| 372 |
+
if all_processed_data:
|
| 373 |
+
stats = calculate_category_stats(all_processed_data)
|
| 374 |
+
print(f"\n=== Machine {machine_id} Statistics ===")
|
| 375 |
+
print_stats(stats)
|
| 376 |
+
|
| 377 |
+
# Save results
|
| 378 |
+
if num_machines > 1:
|
| 379 |
+
# Save current machine's results
|
| 380 |
+
machine_output_path = output_path.replace('.json', f'_machine_{machine_id}.json')
|
| 381 |
+
with open(machine_output_path, "w") as f:
|
| 382 |
+
json.dump(all_processed_data, f, indent=4)
|
| 383 |
+
print(f"Machine {machine_id} results saved to {machine_output_path}")
|
| 384 |
+
|
| 385 |
+
# If this is machine 0, try to gather results from all machines
|
| 386 |
+
if machine_id == 0:
|
| 387 |
+
print("Waiting for all machines to complete...")
|
| 388 |
+
# Note: In practice, you might want to implement a proper synchronization mechanism
|
| 389 |
+
# For now, this assumes all machine files exist
|
| 390 |
+
final_result = {}
|
| 391 |
+
for i in range(num_machines):
|
| 392 |
+
machine_file = output_path.replace('.json', f'_machine_{i}.json')
|
| 393 |
+
if os.path.exists(machine_file):
|
| 394 |
+
print(f"Loading results from machine {i}")
|
| 395 |
+
with open(machine_file, 'r') as f:
|
| 396 |
+
machine_data = json.load(f)
|
| 397 |
+
# Merge machine data
|
| 398 |
+
for category, data_list in machine_data.items():
|
| 399 |
+
if category not in final_result:
|
| 400 |
+
final_result[category] = []
|
| 401 |
+
final_result[category].extend(data_list)
|
| 402 |
+
else:
|
| 403 |
+
print(f"Warning: Machine {i} results file not found: {machine_file}")
|
| 404 |
+
|
| 405 |
+
# Calculate and print statistics for final results
|
| 406 |
+
stats = calculate_category_stats(final_result)
|
| 407 |
+
print("\n=== Final Combined Statistics ===")
|
| 408 |
+
print_stats(stats)
|
| 409 |
+
|
| 410 |
+
# Save final combined results with statistics
|
| 411 |
+
final_output = {
|
| 412 |
+
'statistics': stats,
|
| 413 |
+
'data': final_result,
|
| 414 |
+
}
|
| 415 |
+
with open(output_path, "w") as f:
|
| 416 |
+
json.dump(final_output, f, indent=4)
|
| 417 |
+
print(f"Final combined results saved to {output_path}")
|
| 418 |
+
else:
|
| 419 |
+
# Single machine case - calculate statistics
|
| 420 |
+
stats = calculate_category_stats(all_processed_data)
|
| 421 |
+
print("\n=== Statistics ===")
|
| 422 |
+
print_stats(stats)
|
| 423 |
+
|
| 424 |
+
# Save results with statistics
|
| 425 |
+
output_data = {
|
| 426 |
+
'statistics': stats,
|
| 427 |
+
'data': all_processed_data,
|
| 428 |
+
}
|
| 429 |
+
with open(output_path, "w") as f:
|
| 430 |
+
json.dump(output_data, f, indent=4)
|
| 431 |
+
print(f"Results saved to {output_path}")
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def parse_args():
|
| 435 |
+
parser = argparse.ArgumentParser(description='Process images with HPSv3 reward inference')
|
| 436 |
+
parser.add_argument('--config_path', type=str, help='Path to the configuration file')
|
| 437 |
+
parser.add_argument('--checkpoint_path', type=str, help='Path to the model checkpoint file')
|
| 438 |
+
parser.add_argument('--mode', type=str, choices=['imagereward','hpsv2', 'hpsv3', 'pickscore', 'aesthetic', 'clip'], default='hpsv3')
|
| 439 |
+
parser.add_argument('--image_folders', type=str, nargs='+', required=True, help='List of image folder paths to process')
|
| 440 |
+
parser.add_argument('--output_path', type=str, required=True, help='Path to save the output JSON file')
|
| 441 |
+
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for processing (default: 16)')
|
| 442 |
+
parser.add_argument('--num_processes', type=int, default=8, help='Number of processes to use (default: 8)')
|
| 443 |
+
parser.add_argument('--num_machines', type=int, default=1, help='Total number of machines (default: 1)')
|
| 444 |
+
parser.add_argument('--machine_id', type=int, default=0, help='ID of current machine (default: 0)')
|
| 445 |
+
|
| 446 |
+
return parser.parse_args()
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
if __name__ == "__main__":
|
| 450 |
+
mp.set_start_method('spawn', force=True)
|
| 451 |
+
|
| 452 |
+
args = parse_args()
|
| 453 |
+
main(
|
| 454 |
+
config_path=args.config_path,
|
| 455 |
+
checkpint_path=args.checkpoint_path,
|
| 456 |
+
mode=args.mode,
|
| 457 |
+
image_folders=args.image_folders,
|
| 458 |
+
output_path=args.output_path,
|
| 459 |
+
batch_size=args.batch_size,
|
| 460 |
+
num_processes=args.num_processes,
|
| 461 |
+
num_machines=args.num_machines,
|
| 462 |
+
machine_id=args.machine_id
|
| 463 |
+
)
|
evaluate/evaluate.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import multiprocessing as mp
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from hpsv3.inference import HPSv3RewardInferencer
|
| 8 |
+
from multiprocessing import Process, Queue
|
| 9 |
+
import math
|
| 10 |
+
import fire
|
| 11 |
+
import prettytable
|
| 12 |
+
|
| 13 |
+
def calc_rank_acc(score_sample, predict_sample):
|
| 14 |
+
tol_cnt = 0.
|
| 15 |
+
true_cnt = 0.
|
| 16 |
+
for idx in range(len(score_sample)):
|
| 17 |
+
item_base = score_sample[idx]["ranking"]
|
| 18 |
+
item = predict_sample[idx]["rewards"]
|
| 19 |
+
for i in range(len(item_base)):
|
| 20 |
+
for j in range(i+1, len(item_base)):
|
| 21 |
+
if item_base[i] > item_base[j]:
|
| 22 |
+
if item[i] >= item[j]:
|
| 23 |
+
tol_cnt += 1
|
| 24 |
+
elif item[i] < item[j]:
|
| 25 |
+
tol_cnt += 1
|
| 26 |
+
true_cnt += 1
|
| 27 |
+
elif item_base[i] < item_base[j]:
|
| 28 |
+
if item[i] > item[j]:
|
| 29 |
+
tol_cnt += 1
|
| 30 |
+
true_cnt += 1
|
| 31 |
+
elif item[i] <= item[j]:
|
| 32 |
+
tol_cnt += 1
|
| 33 |
+
return true_cnt / tol_cnt
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def worker_process(process_id, data_chunk, config_path, checkpoint_path, batch_size, result_queue, mode):
|
| 37 |
+
"""
|
| 38 |
+
Worker function for each process to handle a chunk of data
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
# Each process uses a different GPU (cycle through available GPUs)
|
| 42 |
+
num_gpus = torch.cuda.device_count()
|
| 43 |
+
device = f"cuda:{process_id % num_gpus}" if num_gpus > 0 else "cpu"
|
| 44 |
+
dtype = torch.bfloat16
|
| 45 |
+
|
| 46 |
+
print(f"Process {process_id} starting with device {device}, processing {len(data_chunk)} items")
|
| 47 |
+
|
| 48 |
+
# Initialize model for this process
|
| 49 |
+
inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, device=device, dtype=dtype)
|
| 50 |
+
|
| 51 |
+
process_correct = 0
|
| 52 |
+
process_equal = 0
|
| 53 |
+
process_results = []
|
| 54 |
+
|
| 55 |
+
for batch_start in tqdm(range(0, len(data_chunk), batch_size),
|
| 56 |
+
total=(len(data_chunk) + batch_size - 1) // batch_size,
|
| 57 |
+
desc=f"Process {process_id}"):
|
| 58 |
+
batch_end = min(batch_start + batch_size, len(data_chunk))
|
| 59 |
+
batch_info = data_chunk[batch_start:batch_end]
|
| 60 |
+
if mode == 'pair':
|
| 61 |
+
image_paths_1 = [info["path1"] for info in batch_info]
|
| 62 |
+
image_paths_2 = [info["path2"] for info in batch_info]
|
| 63 |
+
prompts = [info["prompt"] for info in batch_info]
|
| 64 |
+
|
| 65 |
+
with torch.no_grad():
|
| 66 |
+
rewards_1 = inferencer.reward(image_paths_1, prompts)
|
| 67 |
+
rewards_2 = inferencer.reward(image_paths_2, prompts)
|
| 68 |
+
|
| 69 |
+
for i in range(len(batch_info)):
|
| 70 |
+
info = batch_info[i]
|
| 71 |
+
if rewards_1.ndim == 2:
|
| 72 |
+
reward_1, reward_2 = rewards_1[i][0].item(), rewards_2[i][0].item()
|
| 73 |
+
else:
|
| 74 |
+
reward_1, reward_2 = rewards_1[i].item(), rewards_2[i].item()
|
| 75 |
+
|
| 76 |
+
item_result = {
|
| 77 |
+
'reward_1': reward_1,
|
| 78 |
+
'reward_2': reward_2,
|
| 79 |
+
'correct': reward_1 > reward_2,
|
| 80 |
+
'equal': reward_1 == reward_2,
|
| 81 |
+
'info': info
|
| 82 |
+
}
|
| 83 |
+
process_results.append(item_result)
|
| 84 |
+
|
| 85 |
+
print(f"Process {process_id} - Reward 1: {reward_1}, Reward 2: {reward_2}")
|
| 86 |
+
if reward_1 > reward_2:
|
| 87 |
+
process_correct += 1
|
| 88 |
+
if reward_1 == reward_2:
|
| 89 |
+
process_equal += 1
|
| 90 |
+
|
| 91 |
+
elif mode == 'ranking':
|
| 92 |
+
for item in batch_info:
|
| 93 |
+
rewards = inferencer.reward(item["generations"], item["prompt"])
|
| 94 |
+
predict_item = {
|
| 95 |
+
"id": item["id"],
|
| 96 |
+
"prompt": item["prompt"],
|
| 97 |
+
"rewards": rewards
|
| 98 |
+
}
|
| 99 |
+
process_results.append(predict_item)
|
| 100 |
+
# Put results in queue
|
| 101 |
+
if mode == 'pair':
|
| 102 |
+
result_queue.put({
|
| 103 |
+
'process_id': process_id,
|
| 104 |
+
'correct': process_correct,
|
| 105 |
+
'equal': process_equal,
|
| 106 |
+
'total': len(data_chunk),
|
| 107 |
+
'results': process_results
|
| 108 |
+
})
|
| 109 |
+
elif mode == 'ranking':
|
| 110 |
+
result_queue.put({
|
| 111 |
+
'process_id': process_id,
|
| 112 |
+
'results': process_results
|
| 113 |
+
})
|
| 114 |
+
|
| 115 |
+
print(f"Process {process_id} completed: {process_correct}/{len(data_chunk)} correct, {process_equal}/{len(data_chunk)} equal")
|
| 116 |
+
|
| 117 |
+
def main(test_json, config_path=None, batch_size=8, num_processes=8, checkpoint_path=None, mode='pair'):
|
| 118 |
+
|
| 119 |
+
assert mode in ['pair', 'ranking'], "Mode must be either 'pair' or 'ranking'"
|
| 120 |
+
assert checkpoint_path is not None, "Checkpoint path must be provided for inference"
|
| 121 |
+
|
| 122 |
+
mp.set_start_method('spawn', force=True)
|
| 123 |
+
|
| 124 |
+
info_list = json.load(open(test_json, "r"))
|
| 125 |
+
|
| 126 |
+
print(f"Total items to process: {len(info_list)}")
|
| 127 |
+
# Split data into chunks for each process
|
| 128 |
+
chunk_size = math.ceil(len(info_list) / num_processes)
|
| 129 |
+
data_chunks = []
|
| 130 |
+
for i in range(num_processes):
|
| 131 |
+
start_idx = i * chunk_size
|
| 132 |
+
end_idx = min((i + 1) * chunk_size, len(info_list))
|
| 133 |
+
if start_idx < len(info_list):
|
| 134 |
+
chunk = info_list[start_idx:end_idx]
|
| 135 |
+
data_chunks.append(chunk)
|
| 136 |
+
print(f"Process {i}: {len(chunk)} items (indices {start_idx}-{end_idx-1})")
|
| 137 |
+
|
| 138 |
+
# Ensure we have the right number of non-empty chunks
|
| 139 |
+
actual_processes = len(data_chunks)
|
| 140 |
+
print(f"Using {actual_processes} processes")
|
| 141 |
+
|
| 142 |
+
# Create result queue and processes
|
| 143 |
+
result_queue = Queue()
|
| 144 |
+
processes = []
|
| 145 |
+
|
| 146 |
+
print("Starting processes...")
|
| 147 |
+
for i in range(actual_processes):
|
| 148 |
+
p = Process(target=worker_process, args=(i, data_chunks[i], config_path, checkpoint_path, batch_size, result_queue, mode))
|
| 149 |
+
p.start()
|
| 150 |
+
processes.append(p)
|
| 151 |
+
|
| 152 |
+
# Wait for all processes to complete and collect results
|
| 153 |
+
all_results = []
|
| 154 |
+
total_correct = 0
|
| 155 |
+
total_equal = 0
|
| 156 |
+
total_items = 0
|
| 157 |
+
|
| 158 |
+
print("Waiting for processes to complete...")
|
| 159 |
+
for i in range(actual_processes):
|
| 160 |
+
result = result_queue.get()
|
| 161 |
+
all_results.append(result)
|
| 162 |
+
if mode == 'pair':
|
| 163 |
+
total_correct += result['correct']
|
| 164 |
+
total_equal += result['equal']
|
| 165 |
+
total_items += result['total']
|
| 166 |
+
|
| 167 |
+
print(f"Process {result['process_id']} finished: {result['correct']}/{result['total']} correct, {result['equal']}/{result['total']} equal")
|
| 168 |
+
|
| 169 |
+
# Wait for all processes to join
|
| 170 |
+
for p in processes:
|
| 171 |
+
p.join()
|
| 172 |
+
|
| 173 |
+
if mode == 'pair':
|
| 174 |
+
aggregated_results = {
|
| 175 |
+
'total_correct': total_correct,
|
| 176 |
+
'total_equal': total_equal,
|
| 177 |
+
'total_items': total_items,
|
| 178 |
+
'accuracy': total_correct / total_items,
|
| 179 |
+
'process_results': all_results
|
| 180 |
+
}
|
| 181 |
+
table = prettytable.PrettyTable()
|
| 182 |
+
table.field_names = ["Total Items", "Correct", "Equal", "Incorrect", "Accuracy (%)"]
|
| 183 |
+
|
| 184 |
+
incorrect = aggregated_results['total_items'] - aggregated_results['total_correct'] - aggregated_results['total_equal']
|
| 185 |
+
accuracy_percent = 100 * aggregated_results['total_correct'] / aggregated_results['total_items']
|
| 186 |
+
|
| 187 |
+
table.add_row([
|
| 188 |
+
aggregated_results['total_items'],
|
| 189 |
+
aggregated_results['total_correct'],
|
| 190 |
+
aggregated_results['total_equal'],
|
| 191 |
+
incorrect,
|
| 192 |
+
f"{accuracy_percent:.2f}"
|
| 193 |
+
])
|
| 194 |
+
elif mode == 'ranking':
|
| 195 |
+
rank_acc = calc_rank_acc(info_list, all_results[0]['results'])
|
| 196 |
+
table = prettytable.PrettyTable()
|
| 197 |
+
table.field_names = ["Total Items", "Rank Accuracy (%)"]
|
| 198 |
+
table.add_row([len(info_list), f"{rank_acc * 100:.2f}"])
|
| 199 |
+
|
| 200 |
+
print(table)
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
|
| 203 |
+
fire.Fire(main)
|
generate/README.md
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image Generation Module
|
| 2 |
+
|
| 3 |
+
This module is designed for generating images from text prompts using various pretrained diffusion models. It supports parallel generation across multiple GPUs and can be extended to include new models easily.
|
| 4 |
+
|
| 5 |
+
## File Structure
|
| 6 |
+
|
| 7 |
+
- `gen_images_from_prompt.py`: The main script for running the image generation process. It reads prompts from a JSON file and handles command-line arguments.
|
| 8 |
+
- `generator.py`: Contains the core `Generator` class, which manages the model pipelines and distributes the generation tasks across different devices.
|
| 9 |
+
- `utils/pipelines.py`: Defines the configurations for all supported pretrained models. This is where you can add or modify model parameters.
|
| 10 |
+
- `utils/utils.py`: Contains helper functions for initializing `diffusers` pipelines and interacting with model APIs.
|
| 11 |
+
|
| 12 |
+
## How to Use
|
| 13 |
+
|
| 14 |
+
To generate images, run the main script with the required arguments.
|
| 15 |
+
|
| 16 |
+
### Basic Command
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
python gen_images_from_prompt.py \
|
| 20 |
+
--json_path /path/to/your/prompts.json \
|
| 21 |
+
--out_dir /path/to/your/output_directory \
|
| 22 |
+
--pipeline_name sd_xl_pipe flux_schnell_pipe
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### Command-Line Arguments
|
| 26 |
+
|
| 27 |
+
- `--json_path` (required): Path to a JSON file containing a list of prompts. Each item in the list should be an object with a `"caption"` key.
|
| 28 |
+
|
| 29 |
+
- For generating images according to real images, you should specify `"image_file"` which is the original image path, and `"aspect_ratio"` of this image. The specific height and width will be adjusted according to model's best practice resolution.
|
| 30 |
+
|
| 31 |
+
- For generating images from prompt only, you should specify `"save_name"`, `"height"` and `"width"`
|
| 32 |
+
**Example `prompts.json` format:**
|
| 33 |
+
```json
|
| 34 |
+
[
|
| 35 |
+
{
|
| 36 |
+
"image_file": "1.jpg",
|
| 37 |
+
"caption": "A beautiful landscape painting of a mountain range at sunset.",
|
| 38 |
+
"aspect_ratio": 0.5,
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"image_file": "2.jpg",
|
| 42 |
+
"caption": "A close-up photo of a red rose with water droplets.",
|
| 43 |
+
"aspect_ratio": 1.0,
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"image_file": "3.jpg",
|
| 47 |
+
"caption": "An astronaut riding a horse on Mars, digital art.",
|
| 48 |
+
"aspect_ratio": 1.77,
|
| 49 |
+
}
|
| 50 |
+
]
|
| 51 |
+
```
|
| 52 |
+
- `--out_dir` (required): The root directory where generated images will be saved. A subdirectory will be created for each pipeline.
|
| 53 |
+
- `--pipeline_name` (required): One or more pipeline configuration names to use for generation. These names must correspond to the `PipelineParam` variable names defined in `utils/pipelines.py`.
|
| 54 |
+
- `--num_devices`: The number of GPU devices to use for generation. Defaults to `8`.
|
| 55 |
+
- `--batch_size`: The batch size per device. Defaults to `1`.
|
| 56 |
+
- `--num_machine`: The total number of machines used in a distributed setup. Defaults to `1`.
|
| 57 |
+
- `--machine_id`: The ID of the current machine in a distributed setup. Defaults to `0`.
|
| 58 |
+
- `--enable_availabel_check`: If set, the script will first run a quick check on a small batch to ensure each pipeline can be loaded and run without errors.
|
| 59 |
+
- `--reverse`: If set, the order of the specified pipelines will be reversed.
|
| 60 |
+
|
| 61 |
+
## How to Add a New Model
|
| 62 |
+
|
| 63 |
+
You can easily add a new text-to-image model by configuring it in the `utils/pipelines.py` file.
|
| 64 |
+
|
| 65 |
+
1. **Open `utils/pipelines.py`**.
|
| 66 |
+
2. **Import `PipelineParam`** if it's not already imported.
|
| 67 |
+
3. **Create a new `PipelineParam` instance** for your model. Define the following parameters:
|
| 68 |
+
- `pipeline_name`: The model's path on the Hugging Face Hub or a local directory.
|
| 69 |
+
- `generation_path`: The name of the subdirectory where the output images will be saved.
|
| 70 |
+
- `pipeline_type`: The type of pipeline, e.g., `'t2i'` (text-to-image) or `'t2v'` (text-to-video). Defaults to `'t2i'`.
|
| 71 |
+
- `pipe_init_kwargs`: A dictionary of arguments required for initializing the model pipeline (e.g., `{"torch_dtype": torch.float16}`).
|
| 72 |
+
- `generation_kwargs`: A dictionary of arguments for the generation process (e.g., `{"guidance_scale": 7.0, "num_inference_steps": 28}`).
|
| 73 |
+
- `base_resolution`: The base resolution the model was trained on (e.g., `1024`).
|
| 74 |
+
- `force_aspect_ratio`: Optionally force a specific aspect ratio (e.g., `1` for square images).
|
| 75 |
+
|
| 76 |
+
**Example:**
|
| 77 |
+
|
| 78 |
+
```python
|
| 79 |
+
from pydantic import BaseModel, Field
|
| 80 |
+
import torch
|
| 81 |
+
|
| 82 |
+
class PipelineParam(BaseModel):
|
| 83 |
+
# ... (class definition)
|
| 84 |
+
|
| 85 |
+
# Add your new model configuration
|
| 86 |
+
my_new_model_pipe = PipelineParam(
|
| 87 |
+
pipeline_name='organization/my-cool-model',
|
| 88 |
+
generation_path=f'generation/my_cool_model',
|
| 89 |
+
pipe_init_kwargs={
|
| 90 |
+
"torch_dtype": torch.float16,
|
| 91 |
+
},
|
| 92 |
+
base_resolution=1024,
|
| 93 |
+
generation_kwargs={
|
| 94 |
+
"guidance_scale": 5.0,
|
| 95 |
+
"num_inference_steps": 30,
|
| 96 |
+
}
|
| 97 |
+
)
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
4. **Run the generation script** using the name of your new `PipelineParam` variable in the `--pipeline_name` argument.
|
| 101 |
+
|
| 102 |
+
```bash
|
| 103 |
+
python gen_images_from_prompt.py --pipeline_name my_new_model_pipe ...
|
| 104 |
+
```
|
generate/__init__.py
ADDED
|
File without changes
|
generate/gen_images_from_prompt.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from generator import Generator
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import gc
|
| 6 |
+
from utils.pipelines import *
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
def parse_args():
|
| 10 |
+
parser = argparse.ArgumentParser(description="生成图片")
|
| 11 |
+
parser.add_argument(
|
| 12 |
+
"--json_path",
|
| 13 |
+
type=str,
|
| 14 |
+
help="json路径",
|
| 15 |
+
)
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
"--out_dir",
|
| 18 |
+
type=str,
|
| 19 |
+
help="输出目录",
|
| 20 |
+
)
|
| 21 |
+
parser.add_argument("--num_devices", type=int, default=8, help="设备数量")
|
| 22 |
+
parser.add_argument("--batch_size", type=int, default=1, help="批量大小")
|
| 23 |
+
parser.add_argument("--num_machine", type=int, default=1, help="机器数量")
|
| 24 |
+
parser.add_argument("--machine_id", type=int, default=0, help="机器id")
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--pipeline_name", type=str, nargs="+", default=None, help="pipeline名称"
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument("--enable_availabel_check", action="store_true")
|
| 29 |
+
parser.add_argument("--reverse", action="store_true")
|
| 30 |
+
return parser.parse_args()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main():
|
| 34 |
+
args = parse_args()
|
| 35 |
+
num_devices = args.num_devices
|
| 36 |
+
pipeline_params = [globals()[f"{name}_pipe"] for name in args.pipeline_name]
|
| 37 |
+
|
| 38 |
+
if args.reverse:
|
| 39 |
+
pipeline_params = pipeline_params[::-1]
|
| 40 |
+
|
| 41 |
+
# first check all pipeline
|
| 42 |
+
if args.enable_availabel_check:
|
| 43 |
+
print(f"Checking {len(pipeline_params)} pipelines")
|
| 44 |
+
for pipeline_param in pipeline_params:
|
| 45 |
+
generator = Generator(
|
| 46 |
+
pipe_name=pipeline_param.pipeline_name,
|
| 47 |
+
pipe_type=pipeline_param.pipeline_type,
|
| 48 |
+
pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
|
| 49 |
+
num_devices=num_devices,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
with open(args.json_path, "r") as f:
|
| 53 |
+
entries = json.load(f)
|
| 54 |
+
info_dict = entries[: args.batch_size]
|
| 55 |
+
generator.generate(
|
| 56 |
+
info_dict,
|
| 57 |
+
os.path.join(args.out_dir, pipeline_param.generation_path),
|
| 58 |
+
batch_size=args.batch_size,
|
| 59 |
+
num_processes=num_devices,
|
| 60 |
+
seed=42,
|
| 61 |
+
weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
|
| 62 |
+
generation_kwargs=pipeline_param.generation_kwargs,
|
| 63 |
+
base_resolution=pipeline_param.base_resolution,
|
| 64 |
+
force_aspect_ratio=pipeline_param.force_aspect_ratio,
|
| 65 |
+
)
|
| 66 |
+
del generator
|
| 67 |
+
gc.collect()
|
| 68 |
+
torch.cuda.empty_cache()
|
| 69 |
+
print(f"Finished Checking {pipeline_param.pipeline_name}")
|
| 70 |
+
|
| 71 |
+
for pipeline_param in pipeline_params:
|
| 72 |
+
generator = Generator(
|
| 73 |
+
pipe_name=pipeline_param.pipeline_name,
|
| 74 |
+
pipe_type=pipeline_param.pipeline_type,
|
| 75 |
+
pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
|
| 76 |
+
num_devices=num_devices,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
with open(args.json_path, "r") as f:
|
| 80 |
+
entries = json.load(f)
|
| 81 |
+
|
| 82 |
+
for i in range(args.num_machine):
|
| 83 |
+
start_idx = i * len(entries) // args.num_machine
|
| 84 |
+
end_idx = (
|
| 85 |
+
(i + 1) * len(entries) // args.num_machine
|
| 86 |
+
if i != args.num_machine - 1
|
| 87 |
+
else len(entries)
|
| 88 |
+
)
|
| 89 |
+
if i == args.machine_id:
|
| 90 |
+
info_dict = entries[start_idx:end_idx]
|
| 91 |
+
|
| 92 |
+
info_dict = sorted(info_dict, key=lambda x: x["aspect_ratio"])
|
| 93 |
+
|
| 94 |
+
print(f"Generating {len(info_dict)} images")
|
| 95 |
+
generator.generate(
|
| 96 |
+
info_dict,
|
| 97 |
+
os.path.join(args.out_dir, pipeline_param.generation_path),
|
| 98 |
+
batch_size=args.batch_size,
|
| 99 |
+
num_processes=num_devices,
|
| 100 |
+
seed=42,
|
| 101 |
+
weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
|
| 102 |
+
generation_kwargs=pipeline_param.generation_kwargs,
|
| 103 |
+
base_resolution=pipeline_param.base_resolution,
|
| 104 |
+
force_aspect_ratio=pipeline_param.force_aspect_ratio,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
print(f"Finished generating {pipeline_param.pipeline_name}")
|
| 108 |
+
|
| 109 |
+
for pipeline in generator.pipelines:
|
| 110 |
+
pipeline.to("cpu")
|
| 111 |
+
del generator
|
| 112 |
+
torch.cuda.empty_cache()
|
| 113 |
+
gc.collect()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
main()
|
generate/generator.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import inspect
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from utils.utils import init_multiple_pipelines
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 8 |
+
|
| 9 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Generator:
|
| 13 |
+
def __init__(
|
| 14 |
+
self, pipe_name, pipe_type, pipe_init_kwargs, num_devices, device_id=None
|
| 15 |
+
):
|
| 16 |
+
self.pipe_names = pipe_name
|
| 17 |
+
self.pipe_type = pipe_type
|
| 18 |
+
self.pipe_init_kwargs = pipe_init_kwargs
|
| 19 |
+
self.pipelines = init_multiple_pipelines(
|
| 20 |
+
pipe_name, pipe_init_kwargs, num_devices, device_id
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
def generate_imgs(
|
| 24 |
+
self,
|
| 25 |
+
num_device,
|
| 26 |
+
batch_size,
|
| 27 |
+
generation_path,
|
| 28 |
+
info_dict,
|
| 29 |
+
pipeline,
|
| 30 |
+
device_id,
|
| 31 |
+
weight_dtype,
|
| 32 |
+
seed,
|
| 33 |
+
base_resolution,
|
| 34 |
+
force_aspect_ratio,
|
| 35 |
+
generation_kwargs,
|
| 36 |
+
|
| 37 |
+
):
|
| 38 |
+
|
| 39 |
+
torch.cuda.set_device(f"cuda:{device_id%num_device}")
|
| 40 |
+
device = torch.device(f"cuda:{device_id%num_device}")
|
| 41 |
+
|
| 42 |
+
num_prompts_per_device = len(info_dict) // num_device
|
| 43 |
+
start_idx = device_id * num_prompts_per_device
|
| 44 |
+
end_idx = (
|
| 45 |
+
start_idx + num_prompts_per_device
|
| 46 |
+
if device_id != (num_device - 1)
|
| 47 |
+
else len(info_dict)
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
device_info_dict = info_dict[start_idx:end_idx]
|
| 51 |
+
|
| 52 |
+
print(f"Device {device} generating for prompts {start_idx} to {end_idx-1}")
|
| 53 |
+
|
| 54 |
+
print("## Prepare generation dataset")
|
| 55 |
+
|
| 56 |
+
total_batches = len(device_info_dict) // batch_size + (
|
| 57 |
+
1 if len(device_info_dict) % batch_size != 0 else 0
|
| 58 |
+
)
|
| 59 |
+
for batch_idx in tqdm(
|
| 60 |
+
range(total_batches), desc="Pipeline: " + self.pipe_names
|
| 61 |
+
):
|
| 62 |
+
batch_info_dict = device_info_dict[
|
| 63 |
+
batch_idx * batch_size : (batch_idx + 1) * batch_size
|
| 64 |
+
]
|
| 65 |
+
save_paths = []
|
| 66 |
+
for info_dict in batch_info_dict:
|
| 67 |
+
if info_dict["image_file"] is not None:
|
| 68 |
+
save_paths.append(
|
| 69 |
+
os.path.join(generation_path, info_dict["image_file"][:-4] + ".png")
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
save_paths.append(
|
| 73 |
+
os.path.join(generation_path, info_dict["save_name"] + ".png")
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
exists_idx = []
|
| 77 |
+
for i, save_path in enumerate(save_paths):
|
| 78 |
+
if os.path.exists(save_path):
|
| 79 |
+
exists_idx.append(i)
|
| 80 |
+
|
| 81 |
+
batch_info_dict = [
|
| 82 |
+
batch_info_dict[i]
|
| 83 |
+
for i in range(len(batch_info_dict))
|
| 84 |
+
if i not in exists_idx
|
| 85 |
+
]
|
| 86 |
+
if len(batch_info_dict) == 0:
|
| 87 |
+
continue
|
| 88 |
+
|
| 89 |
+
batch_prompts = [info_dict["caption"] for info_dict in batch_info_dict]
|
| 90 |
+
batch_image_file = [
|
| 91 |
+
info_dict["image_file"] for info_dict in batch_info_dict
|
| 92 |
+
]
|
| 93 |
+
if batch_image_file[0] is not None:
|
| 94 |
+
try:
|
| 95 |
+
batch_image_sizes = [
|
| 96 |
+
Image.open(image_file).size for image_file in batch_image_file
|
| 97 |
+
]
|
| 98 |
+
except:
|
| 99 |
+
batch_image_sizes = None
|
| 100 |
+
else:
|
| 101 |
+
batch_image_sizes = [
|
| 102 |
+
(batch_info_dict[i]["width"], batch_info_dict[i]["height"])
|
| 103 |
+
for i in range(len(batch_info_dict))
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
if batch_image_sizes is None:
|
| 107 |
+
aspect_ratios = [
|
| 108 |
+
info_dict["aspect_ratio"] for info_dict in batch_info_dict
|
| 109 |
+
]
|
| 110 |
+
else:
|
| 111 |
+
aspect_ratios = [size[0] / size[1] for size in batch_image_sizes]
|
| 112 |
+
|
| 113 |
+
if force_aspect_ratio:
|
| 114 |
+
height = int(base_resolution / force_aspect_ratio // 64 * 64)
|
| 115 |
+
width = int(base_resolution * force_aspect_ratio // 64 * 64)
|
| 116 |
+
else:
|
| 117 |
+
# 根据aspect_ratios调整base_resolution, 得到height和width, 保证调整后的乘积大概等于base_resolution**2
|
| 118 |
+
height = int(base_resolution / aspect_ratios[0] ** (0.5) // 64 * 64)
|
| 119 |
+
width = int(base_resolution * aspect_ratios[0] ** (0.5) // 64 * 64)
|
| 120 |
+
generation_kwargs.update({"height": height, "width": width})
|
| 121 |
+
|
| 122 |
+
generator = torch.Generator().manual_seed(seed + batch_idx)
|
| 123 |
+
|
| 124 |
+
pipeline_signature = inspect.signature(pipeline)
|
| 125 |
+
pipeline_params = pipeline_signature.parameters.keys()
|
| 126 |
+
|
| 127 |
+
if 'height' not in pipeline_params:
|
| 128 |
+
generation_kwargs.pop('height', None)
|
| 129 |
+
print(f"Warning: Pipeline does not support 'height' parameter, removing from kwargs")
|
| 130 |
+
if 'width' not in pipeline_params:
|
| 131 |
+
generation_kwargs.pop('width', None)
|
| 132 |
+
print(f"Warning: Pipeline does not support 'width' parameter, removing from kwargs")
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
outputs = pipeline(
|
| 136 |
+
prompt=batch_prompts, generator=generator, **generation_kwargs
|
| 137 |
+
)
|
| 138 |
+
except Exception as e:
|
| 139 |
+
print(e)
|
| 140 |
+
continue
|
| 141 |
+
if self.pipe_type == "t2i":
|
| 142 |
+
images = outputs.images
|
| 143 |
+
elif self.pipe_type == "t2v":
|
| 144 |
+
images = outputs.frames[0]
|
| 145 |
+
|
| 146 |
+
for img_idx, (img, prompt, image_file, info_dict) in enumerate(
|
| 147 |
+
zip(images, batch_prompts, batch_image_file, batch_info_dict)
|
| 148 |
+
):
|
| 149 |
+
if image_file is None:
|
| 150 |
+
img_path = os.path.join(
|
| 151 |
+
generation_path, info_dict["save_name"] + ".png"
|
| 152 |
+
)
|
| 153 |
+
else:
|
| 154 |
+
img_path = generation_path + image_file[:-4] + ".png"
|
| 155 |
+
|
| 156 |
+
if not os.path.exists(os.path.dirname(img_path)):
|
| 157 |
+
os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
| 158 |
+
img.save(img_path)
|
| 159 |
+
if image_file is None:
|
| 160 |
+
text_path = os.path.join(
|
| 161 |
+
generation_path, info_dict["save_name"] + ".txt"
|
| 162 |
+
)
|
| 163 |
+
else:
|
| 164 |
+
text_path = generation_path + image_file[:-4] + ".txt"
|
| 165 |
+
try:
|
| 166 |
+
with open(text_path, "w") as f:
|
| 167 |
+
f.write(prompt)
|
| 168 |
+
f.write("\n")
|
| 169 |
+
f.write(
|
| 170 |
+
image_file
|
| 171 |
+
if image_file is not None
|
| 172 |
+
else info_dict["save_name"]
|
| 173 |
+
)
|
| 174 |
+
except:
|
| 175 |
+
pass
|
| 176 |
+
return True
|
| 177 |
+
|
| 178 |
+
def generate(
|
| 179 |
+
self,
|
| 180 |
+
info_dict,
|
| 181 |
+
generation_path,
|
| 182 |
+
num_processes,
|
| 183 |
+
batch_size,
|
| 184 |
+
weight_dtype,
|
| 185 |
+
seed,
|
| 186 |
+
generation_kwargs,
|
| 187 |
+
base_resolution,
|
| 188 |
+
force_aspect_ratio,
|
| 189 |
+
):
|
| 190 |
+
|
| 191 |
+
with ThreadPoolExecutor(max_workers=num_processes) as executor:
|
| 192 |
+
futures = [
|
| 193 |
+
executor.submit(
|
| 194 |
+
self.generate_imgs,
|
| 195 |
+
num_processes,
|
| 196 |
+
batch_size,
|
| 197 |
+
generation_path,
|
| 198 |
+
info_dict,
|
| 199 |
+
self.pipelines[device_id],
|
| 200 |
+
device_id,
|
| 201 |
+
weight_dtype,
|
| 202 |
+
seed,
|
| 203 |
+
base_resolution,
|
| 204 |
+
force_aspect_ratio,
|
| 205 |
+
generation_kwargs,
|
| 206 |
+
)
|
| 207 |
+
for device_id in range(num_processes)
|
| 208 |
+
]
|
| 209 |
+
|
| 210 |
+
for future in as_completed(futures):
|
| 211 |
+
print(f"Task completed: {future.result()}")
|
generate/utils/__init__.py
ADDED
|
File without changes
|
generate/utils/pipelines.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
from typing import Optional, Dict, Any
|
| 4 |
+
|
| 5 |
+
class PipelineParam(BaseModel):
|
| 6 |
+
pipeline_name: str
|
| 7 |
+
generation_path: str
|
| 8 |
+
pipeline_type: str = 't2i'
|
| 9 |
+
pipe_init_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
| 10 |
+
generation_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
| 11 |
+
base_resolution: int = 1024
|
| 12 |
+
force_aspect_ratio: Optional[int] = None
|
| 13 |
+
|
| 14 |
+
flux_dev_pipe = PipelineParam(
|
| 15 |
+
pipeline_name='pretrained_models/FLUX.1-dev',
|
| 16 |
+
generation_path=f'generation/flux_dev',
|
| 17 |
+
pipe_init_kwargs={
|
| 18 |
+
"torch_dtype": torch.bfloat16,
|
| 19 |
+
},
|
| 20 |
+
base_resolution=1024,
|
| 21 |
+
generation_kwargs={
|
| 22 |
+
"guidance_scale": 3.5,
|
| 23 |
+
"num_inference_steps": 28,
|
| 24 |
+
"max_sequence_length": 512,
|
| 25 |
+
}
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
flux_schnell_pipe = PipelineParam(
|
| 29 |
+
pipeline_name='pretrained_models/FLUX.1-schnell',
|
| 30 |
+
generation_path=f'generation/flux_schnell',
|
| 31 |
+
pipe_init_kwargs={
|
| 32 |
+
"torch_dtype": torch.bfloat16,
|
| 33 |
+
},
|
| 34 |
+
base_resolution=1024,
|
| 35 |
+
generation_kwargs={
|
| 36 |
+
"guidance_scale": 3.5,
|
| 37 |
+
"num_inference_steps": 4,
|
| 38 |
+
}
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
sd3_medium_pipe = PipelineParam(
|
| 43 |
+
pipeline_name='pretrained_models/stable-diffusion-3-medium-diffusers',
|
| 44 |
+
generation_path=f'generation/sd3_medium',
|
| 45 |
+
pipe_init_kwargs={
|
| 46 |
+
"torch_dtype": torch.float16,
|
| 47 |
+
},
|
| 48 |
+
base_resolution=1024,
|
| 49 |
+
generation_kwargs={
|
| 50 |
+
"guidance_scale": 7.0,
|
| 51 |
+
"num_inference_steps": 28,
|
| 52 |
+
}
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
sd_xl_pipe = PipelineParam(
|
| 56 |
+
pipeline_name='pretrained_models/stable-diffusion-xl-base-1.0',
|
| 57 |
+
generation_path=f'generation/sd_xl',
|
| 58 |
+
pipe_init_kwargs={
|
| 59 |
+
"torch_dtype": torch.float16,
|
| 60 |
+
},
|
| 61 |
+
base_resolution=1024,
|
| 62 |
+
generation_kwargs={
|
| 63 |
+
"guidance_scale": 5,
|
| 64 |
+
"num_inference_steps": 50,
|
| 65 |
+
}
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
sd_1_5_pipe = PipelineParam(
|
| 69 |
+
pipeline_name='pretrained_models/stable-diffusion-v1-5',
|
| 70 |
+
generation_path=f'generation/sd_1_5',
|
| 71 |
+
pipe_init_kwargs={
|
| 72 |
+
"torch_dtype": torch.float16,
|
| 73 |
+
},
|
| 74 |
+
base_resolution=512,
|
| 75 |
+
generation_kwargs={
|
| 76 |
+
}
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
vq_diffusion_pipe = PipelineParam(
|
| 80 |
+
pipeline_name='pretrained_models/vq-diffusion-ithq',
|
| 81 |
+
generation_path=f'generation/vq_diffusion',
|
| 82 |
+
pipe_init_kwargs={
|
| 83 |
+
"torch_dtype": torch.float16,
|
| 84 |
+
},
|
| 85 |
+
base_resolution=256,
|
| 86 |
+
generation_kwargs={}
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
sd_2_pipe = PipelineParam(
|
| 90 |
+
pipeline_name='pretrained_models/stable-diffusion-2',
|
| 91 |
+
generation_path=f'generation/sd_2',
|
| 92 |
+
pipe_init_kwargs={
|
| 93 |
+
"torch_dtype": torch.float16,
|
| 94 |
+
},
|
| 95 |
+
base_resolution=512,
|
| 96 |
+
force_aspect_ratio=1,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
sd_1_1_pipe = PipelineParam(
|
| 100 |
+
pipeline_name='pretrained_models/stable-diffusion-v1-1',
|
| 101 |
+
generation_path=f'generation/sd_1_1',
|
| 102 |
+
pipe_init_kwargs={"torch_dtype": torch.float16,},
|
| 103 |
+
base_resolution=512,
|
| 104 |
+
force_aspect_ratio=1,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
sd_1_4_pipe = PipelineParam(
|
| 108 |
+
pipeline_name='pretrained_models/stable-diffusion-v1-4',
|
| 109 |
+
generation_path=f'generation/sd_1_4',
|
| 110 |
+
pipe_init_kwargs={
|
| 111 |
+
"torch_dtype": torch.float16,
|
| 112 |
+
},
|
| 113 |
+
base_resolution=512,
|
| 114 |
+
force_aspect_ratio=1,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
sd_2_1_pipe = PipelineParam(
|
| 118 |
+
pipeline_name='pretrained_models/stable-diffusion-2-1-base',
|
| 119 |
+
generation_path=f'generation/sd_2_1',
|
| 120 |
+
pipe_init_kwargs={
|
| 121 |
+
"torch_dtype": torch.float16,
|
| 122 |
+
},
|
| 123 |
+
base_resolution=512,
|
| 124 |
+
force_aspect_ratio=1,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
openjourney_pipe = PipelineParam(
|
| 128 |
+
pipeline_name='pretrained_models/openjourney',
|
| 129 |
+
generation_path=f'generation/openjourney',
|
| 130 |
+
pipe_init_kwargs={
|
| 131 |
+
"torch_dtype": torch.float16,
|
| 132 |
+
},
|
| 133 |
+
base_resolution=512,
|
| 134 |
+
force_aspect_ratio=1,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
playground_v2_5_pipe = PipelineParam(
|
| 138 |
+
pipeline_name='pretrained_models/playground-v2.5-1024px-aesthetic',
|
| 139 |
+
generation_path=f'generation/playground_v_2_5',
|
| 140 |
+
pipe_init_kwargs={
|
| 141 |
+
"torch_dtype": torch.float16,
|
| 142 |
+
},
|
| 143 |
+
base_resolution=1024,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
versatile_pipe = PipelineParam(
|
| 147 |
+
pipeline_name='pretrained_models/versatile-diffusion',
|
| 148 |
+
generation_path=f'generation/versatile',
|
| 149 |
+
pipe_init_kwargs={
|
| 150 |
+
"torch_dtype": torch.float16,
|
| 151 |
+
},
|
| 152 |
+
base_resolution=512,
|
| 153 |
+
force_aspect_ratio=1,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
glide_pipe = PipelineParam(
|
| 157 |
+
pipeline_name='pretrained_models/glide-base',
|
| 158 |
+
generation_path=f'generation/glide',
|
| 159 |
+
pipe_init_kwargs={
|
| 160 |
+
"torch_dtype": torch.float16,
|
| 161 |
+
},
|
| 162 |
+
base_resolution=512,
|
| 163 |
+
force_aspect_ratio=1,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
sd_3_5_medium_pipe = PipelineParam(
|
| 167 |
+
pipeline_name='stabilityai/stable-diffusion-3.5-medium',
|
| 168 |
+
generation_path=f'generation/sd_3_5_medium',
|
| 169 |
+
pipe_init_kwargs={
|
| 170 |
+
"torch_dtype": torch.bfloat16,
|
| 171 |
+
},
|
| 172 |
+
base_resolution=1024,
|
| 173 |
+
generation_kwargs={
|
| 174 |
+
"num_inference_steps": 40,
|
| 175 |
+
"guidance_scale": 4.5,
|
| 176 |
+
}
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
sd_3_5_large_pipe = PipelineParam(
|
| 180 |
+
pipeline_name='stabilityai/stable-diffusion-3.5-large',
|
| 181 |
+
generation_path=f'generation/sd_3_5_large',
|
| 182 |
+
pipe_init_kwargs={
|
| 183 |
+
"torch_dtype": torch.bfloat16,
|
| 184 |
+
},
|
| 185 |
+
base_resolution=1024,
|
| 186 |
+
generation_kwargs={
|
| 187 |
+
"num_inference_steps": 28,
|
| 188 |
+
"guidance_scale": 3.5,
|
| 189 |
+
}
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
kolors_pipe = PipelineParam(
|
| 193 |
+
pipeline_name='pretrained_models/Kolors-diffusers',
|
| 194 |
+
generation_path=f'generation/kolors',
|
| 195 |
+
pipe_init_kwargs={
|
| 196 |
+
"torch_dtype": torch.float16,
|
| 197 |
+
'variant': 'fp16',
|
| 198 |
+
},
|
| 199 |
+
base_resolution=1024,
|
| 200 |
+
generation_kwargs={
|
| 201 |
+
"num_inference_steps": 50,
|
| 202 |
+
"guidance_scale": 5.0,
|
| 203 |
+
}
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
cogview4_pipe = PipelineParam(
|
| 207 |
+
pipeline_name='pretrained_models/CogView4-6B',
|
| 208 |
+
generation_path=f'generation/cogview4',
|
| 209 |
+
pipe_init_kwargs={
|
| 210 |
+
"torch_dtype": torch.bfloat16,
|
| 211 |
+
},
|
| 212 |
+
base_resolution=1024,
|
| 213 |
+
generation_kwargs={
|
| 214 |
+
"num_inference_steps": 50,
|
| 215 |
+
"guidance_scale": 3.5,
|
| 216 |
+
}
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
pixart_sigma_pipe = PipelineParam(
|
| 220 |
+
pipeline_name='pretrained_models/PixArt-Sigma-XL-2-1024-MS',
|
| 221 |
+
generation_path=f'generation/pixart_sigma',
|
| 222 |
+
pipeline_type='t2i',
|
| 223 |
+
pipe_init_kwargs={
|
| 224 |
+
"torch_dtype": torch.bfloat16,
|
| 225 |
+
},
|
| 226 |
+
base_resolution=1024,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
hunyuanvideo_pipe = PipelineParam(
|
| 230 |
+
pipeline_name='pretrained_models/hunyuanvideo_diffusers',
|
| 231 |
+
generation_path=f'generation/hunyuanvideo',
|
| 232 |
+
pipe_init_kwargs={
|
| 233 |
+
"torch_dtype": torch.bfloat16,
|
| 234 |
+
},
|
| 235 |
+
base_resolution=1024,
|
| 236 |
+
pipeline_type='t2v',
|
| 237 |
+
generation_kwargs={
|
| 238 |
+
"num_inference_steps": 30,
|
| 239 |
+
"num_frames": 1,
|
| 240 |
+
}
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
hunyuandit_pipe = PipelineParam(
|
| 244 |
+
pipeline_name='pretrained_models/HunyuanDiT-v1.2-Diffusers',
|
| 245 |
+
generation_path=f'generation/hunyuandit',
|
| 246 |
+
pipe_init_kwargs={
|
| 247 |
+
"torch_dtype": torch.float16,
|
| 248 |
+
},
|
| 249 |
+
base_resolution=1024,
|
| 250 |
+
pipeline_type='t2i',
|
| 251 |
+
generation_kwargs={
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# API models
|
| 256 |
+
# Fal.ai
|
| 257 |
+
flux_pro_v1_1_ultr_pipe = PipelineParam(
|
| 258 |
+
pipeline_name='fal-ai/flux-pro/v1.1-ultra',
|
| 259 |
+
generation_path=f'generation/flux_pro_v1_1_ultra',
|
| 260 |
+
base_resolution=1024,
|
| 261 |
+
generation_kwargs={
|
| 262 |
+
"enable_safety_checker": False,
|
| 263 |
+
"num_images": 1,
|
| 264 |
+
# "aspect_ratio": "1:1",
|
| 265 |
+
"output_format": "jpeg",
|
| 266 |
+
"safety_tolerance": 5,
|
| 267 |
+
}
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
recraftv3_pipe = PipelineParam(
|
| 271 |
+
pipeline_name='fal-ai/recraft-v3',
|
| 272 |
+
generation_path=f'generation/recraftv3',
|
| 273 |
+
base_resolution=1024,
|
| 274 |
+
generation_kwargs={
|
| 275 |
+
"enable_safety_checker": False,
|
| 276 |
+
"num_images": 1,
|
| 277 |
+
# "aspect_ratio": "1:1",
|
| 278 |
+
"output_format": "jpeg",
|
| 279 |
+
"safety_tolerance": 5,
|
| 280 |
+
}
|
| 281 |
+
)
|
| 282 |
+
|
generate/utils/utils.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
try:
|
| 3 |
+
import fal_client
|
| 4 |
+
except:
|
| 5 |
+
fal_client = None
|
| 6 |
+
|
| 7 |
+
from diffusers import AutoPipelineForText2Image, HunyuanVideoPipeline, DiffusionPipeline
|
| 8 |
+
import json
|
| 9 |
+
import diffusers
|
| 10 |
+
from functools import partial
|
| 11 |
+
import os
|
| 12 |
+
# export FAL_KEY="YOUR_API_KEY"
|
| 13 |
+
os.environ['FAL_KEY'] = 'YOUR_API_KEY'
|
| 14 |
+
|
| 15 |
+
def init_multiple_pipelines(pipe_name, pipe_init_kwargs, num_devices, device_id=None):
|
| 16 |
+
pipelines_dict = []
|
| 17 |
+
|
| 18 |
+
if device_id is not None:
|
| 19 |
+
assert num_devices == 1
|
| 20 |
+
|
| 21 |
+
for i in range(num_devices):
|
| 22 |
+
actual_device_id = device_id if device_id is not None else i
|
| 23 |
+
try:
|
| 24 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(pipe_name, **pipe_init_kwargs).to(f'cuda:{actual_device_id}')
|
| 25 |
+
except Exception as e:
|
| 26 |
+
# try:
|
| 27 |
+
config = json.load(open(os.path.join(pipe_name, 'model_index.json')))
|
| 28 |
+
class_name_str = config['_class_name']
|
| 29 |
+
pipeline_class = getattr(diffusers, class_name_str)
|
| 30 |
+
pipeline = pipeline_class.from_pretrained(pipe_name, **pipe_init_kwargs).to(f'cuda:{actual_device_id}')
|
| 31 |
+
# except Exception as ew:
|
| 32 |
+
# print(e)
|
| 33 |
+
# pipeline = DiffusionPipeline.from_pretrained(pipe_name, **pipe_init_kwargs).to(f'cuda:{actual_device_id}')
|
| 34 |
+
pipelines_dict.append(pipeline)
|
| 35 |
+
return pipelines_dict
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def init_pipeline_from_names(pipe_names, weight_dtype):
|
| 39 |
+
pipelines_dict = {}
|
| 40 |
+
for name in pipe_names:
|
| 41 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(name, torch_dtype=weight_dtype)
|
| 42 |
+
pipelines_dict[name] = pipeline
|
| 43 |
+
return pipelines_dict
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def on_queue_update(update):
|
| 47 |
+
if isinstance(update, fal_client.InProgress):
|
| 48 |
+
for log in update.logs:
|
| 49 |
+
print(log["message"])
|
| 50 |
+
|
| 51 |
+
def gen_with_api(pipe_names, generation_kwargs):
|
| 52 |
+
result = fal_client.subscribe(
|
| 53 |
+
pipe_names,
|
| 54 |
+
arguments=generation_kwargs,
|
| 55 |
+
with_logs=True,
|
| 56 |
+
on_queue_update=on_queue_update,
|
| 57 |
+
)
|
| 58 |
+
return result
|
hpsv3/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .inference import HPSv3RewardInferencer
|
hpsv3/cohp/__init__.py
ADDED
|
File without changes
|
hpsv3/cohp/cohp_all.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from generator import Generator
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
import gc
|
| 7 |
+
from utils_cohp.pipelines import *
|
| 8 |
+
from utils_cohp.image2image_pipeline import Image2ImagePipeline
|
| 9 |
+
import argparse
|
| 10 |
+
from ..inference import HPSv3RewardInferencer
|
| 11 |
+
import random
|
| 12 |
+
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
|
| 13 |
+
import ImageReward as RM
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from transformers import AutoProcessor, AutoModel
|
| 16 |
+
|
| 17 |
+
def initialize_model(device, cp):
|
| 18 |
+
model_dict = {}
|
| 19 |
+
model, preprocess_train, preprocess_val = create_model_and_transforms(
|
| 20 |
+
'ViT-H-14',
|
| 21 |
+
'laion2B-s32B-b79K',
|
| 22 |
+
precision='amp',
|
| 23 |
+
device=device,
|
| 24 |
+
jit=False,
|
| 25 |
+
force_quick_gelu=False,
|
| 26 |
+
force_custom_text=False,
|
| 27 |
+
force_patch_dropout=False,
|
| 28 |
+
force_image_size=None,
|
| 29 |
+
pretrained_image=False,
|
| 30 |
+
image_mean=None,
|
| 31 |
+
image_std=None,
|
| 32 |
+
light_augmentation=True,
|
| 33 |
+
aug_cfg={},
|
| 34 |
+
output_dict=True,
|
| 35 |
+
with_score_predictor=False,
|
| 36 |
+
with_region_predictor=False
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
checkpoint = torch.load(cp, map_location=device, weights_only=False)
|
| 40 |
+
model.load_state_dict(checkpoint['state_dict'])
|
| 41 |
+
model = model.to(device)
|
| 42 |
+
model.eval()
|
| 43 |
+
tokenizer = get_tokenizer('ViT-H-14')
|
| 44 |
+
|
| 45 |
+
model_dict['model'] = model
|
| 46 |
+
model_dict['preprocess_val'] = preprocess_val
|
| 47 |
+
return model_dict, tokenizer
|
| 48 |
+
|
| 49 |
+
def score_hpsv2_batch(model_dict, tokenizer, device, img_paths: list, prompts: list) -> list:
|
| 50 |
+
model = model_dict['model']
|
| 51 |
+
preprocess_val = model_dict['preprocess_val']
|
| 52 |
+
# 批量处理图片
|
| 53 |
+
images = [preprocess_val(Image.open(p)).unsqueeze(0) for p in img_paths]
|
| 54 |
+
images = torch.cat(images, dim=0).to(device=device)
|
| 55 |
+
texts = tokenizer(prompts).to(device=device)
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
outputs = model(images, texts)
|
| 58 |
+
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
| 59 |
+
logits_per_image = image_features @ text_features.T
|
| 60 |
+
hps_scores = torch.diagonal(logits_per_image).cpu()
|
| 61 |
+
return hps_scores
|
| 62 |
+
def pickscorecalc_probs(model,processor_pickscore,prompt, images, device):
|
| 63 |
+
|
| 64 |
+
# preprocess
|
| 65 |
+
image_inputs = processor_pickscore(
|
| 66 |
+
images=images,
|
| 67 |
+
padding=True,
|
| 68 |
+
truncation=True,
|
| 69 |
+
max_length=77,
|
| 70 |
+
return_tensors="pt",
|
| 71 |
+
).to(device)
|
| 72 |
+
|
| 73 |
+
text_inputs = processor_pickscore(
|
| 74 |
+
text=prompt,
|
| 75 |
+
padding=True,
|
| 76 |
+
truncation=True,
|
| 77 |
+
max_length=77,
|
| 78 |
+
return_tensors="pt",
|
| 79 |
+
).to(device)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
# embed
|
| 84 |
+
image_embs = model.get_image_features(**image_inputs)
|
| 85 |
+
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
| 86 |
+
|
| 87 |
+
text_embs = model.get_text_features(**text_inputs)
|
| 88 |
+
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
|
| 89 |
+
|
| 90 |
+
# score
|
| 91 |
+
scores = text_embs @ image_embs.T
|
| 92 |
+
|
| 93 |
+
return scores
|
| 94 |
+
|
| 95 |
+
def generate_images(reward_type, prompt, index, pipeline_params, di_pipeline, inferencer, out_dir='cohp_output', num_rounds=5, strength=0.8, device='cuda:1'):
|
| 96 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 97 |
+
os.makedirs(os.path.join(out_dir, 'result_json'), exist_ok=True)
|
| 98 |
+
batch_size = 2 # 设置batch大小
|
| 99 |
+
|
| 100 |
+
results = [] # 用于保存每个 prompt 的最终结果
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
info_dict = {
|
| 104 |
+
'caption': prompt,
|
| 105 |
+
'width': 1024,
|
| 106 |
+
'height': 1024,
|
| 107 |
+
'aspect_ratio': 1,
|
| 108 |
+
'save_name': f"{index}_origin",
|
| 109 |
+
}
|
| 110 |
+
di_score_pipelines = {} # 用于存储 pipeline 的平均分数
|
| 111 |
+
|
| 112 |
+
# 中间结果记录结构:用于保存每一轮图像路径和分数
|
| 113 |
+
intermediate_results_sample_preference = []
|
| 114 |
+
intermediate_results_model_preference = []
|
| 115 |
+
|
| 116 |
+
# 遍历 pipeline 参数
|
| 117 |
+
for pipeline_param in pipeline_params:
|
| 118 |
+
|
| 119 |
+
name = di_pipeline[pipeline_param]
|
| 120 |
+
generator = Generator(
|
| 121 |
+
device = device,
|
| 122 |
+
pipe_name=pipeline_param.pipeline_name,
|
| 123 |
+
pipe_type=pipeline_param.pipeline_type,
|
| 124 |
+
pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
|
| 125 |
+
)
|
| 126 |
+
image_paths = generator.generate_imgs(
|
| 127 |
+
info_dict = info_dict,
|
| 128 |
+
generation_path = os.path.join(out_dir, pipeline_param.generation_path),
|
| 129 |
+
batch_size=batch_size,
|
| 130 |
+
device = device,
|
| 131 |
+
seed=random.randint(0, 75859066837),
|
| 132 |
+
weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
|
| 133 |
+
generation_kwargs=pipeline_param.generation_kwargs,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# 对生成的图像进行评分
|
| 137 |
+
score_list = []
|
| 138 |
+
for image_path in image_paths:
|
| 139 |
+
if reward_type == 'hpsv2':
|
| 140 |
+
score = score_hpsv2_batch(model_dict, tokenizer, device, [image_path], [prompt])
|
| 141 |
+
score = score.item()
|
| 142 |
+
elif reward_type == 'hpsv3':
|
| 143 |
+
score = inferencer.reward([image_path], [prompt]).cpu().detach()
|
| 144 |
+
score = score[0][0].item()
|
| 145 |
+
elif reward_type == 'imagereward':
|
| 146 |
+
score = inferencer.score(prompt, [image_path])
|
| 147 |
+
elif reward_type == 'pickscore':
|
| 148 |
+
score = pickscorecalc_probs(inferencer, processor_pickscore, prompt, [Image.open(image_path)],device)[0][0].item()
|
| 149 |
+
print(f"PickScore for {image_path}: {score}")
|
| 150 |
+
else:
|
| 151 |
+
raise ValueError("Unsupported reward type. Choose 'hpsv2', 'hpsv3', or 'imagereward'.")
|
| 152 |
+
score_list.append(score)
|
| 153 |
+
|
| 154 |
+
average = sum(score_list) / len(score_list)
|
| 155 |
+
di_score_pipelines[name] = average
|
| 156 |
+
# 保存中间步骤的图像路径和分数
|
| 157 |
+
intermediate_results_model_preference.append({
|
| 158 |
+
'pipeline': name,
|
| 159 |
+
'image_paths': image_paths, # 所有生成的图片路径
|
| 160 |
+
'scores': score_list, # 每轮的得分列表
|
| 161 |
+
'max_image_path': image_paths[score_list.index(max(score_list))], # 当前轮得分最高的图片路径
|
| 162 |
+
'max_score': max(score_list) # 当前轮得分最高的分数
|
| 163 |
+
})
|
| 164 |
+
|
| 165 |
+
# 清理生成器资源
|
| 166 |
+
generator.pipelines.to("cpu")
|
| 167 |
+
del generator
|
| 168 |
+
torch.cuda.empty_cache()
|
| 169 |
+
gc.collect()
|
| 170 |
+
|
| 171 |
+
# 选择得分最高的 pipeline 和对应的图片
|
| 172 |
+
max_key = max(di_score_pipelines, key=di_score_pipelines.get)
|
| 173 |
+
max_index = score_list.index(max(score_list))
|
| 174 |
+
image_path_chosen = image_paths[max_index] # 首轮选择的最佳图片
|
| 175 |
+
|
| 176 |
+
# 多轮优化循环
|
| 177 |
+
|
| 178 |
+
for round_num in range(num_rounds):
|
| 179 |
+
if round_num == 3 or round_num == 4:
|
| 180 |
+
strength = 0.5
|
| 181 |
+
i2ipipeline = Image2ImagePipeline(max_key)
|
| 182 |
+
images = i2ipipeline.generate_image(
|
| 183 |
+
prompt=prompt,
|
| 184 |
+
image_path=image_path_chosen,
|
| 185 |
+
strength=strength,
|
| 186 |
+
batch_size=4,
|
| 187 |
+
save_prefix=f'{index}_{max_key}_image2image_round{round_num + 1}',
|
| 188 |
+
output_dir=out_dir
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
score_list = []
|
| 192 |
+
for image_path in images:
|
| 193 |
+
if reward_type == 'hpsv2':
|
| 194 |
+
score = score_hpsv2_batch(model_dict, tokenizer, device, [image_path], [prompt])
|
| 195 |
+
score = score.item()
|
| 196 |
+
elif reward_type == 'hpsv3':
|
| 197 |
+
score = inferencer.reward([image_path], [prompt]).cpu().detach()
|
| 198 |
+
score = score[0][0].item()
|
| 199 |
+
elif reward_type == 'imagereward':
|
| 200 |
+
score = inferencer.score(prompt, [image_path])
|
| 201 |
+
elif reward_type == 'pickscore':
|
| 202 |
+
score = pickscorecalc_probs(inferencer, processor_pickscore, prompt, [Image.open(image_path)],device)[0][0].item()
|
| 203 |
+
print(f"PickScore for {image_path}: {score}")
|
| 204 |
+
else:
|
| 205 |
+
raise ValueError("Unsupported reward type. Choose 'hpsv2', 'hpsv3', or 'imagereward'.")
|
| 206 |
+
score_list.append(score)
|
| 207 |
+
|
| 208 |
+
intermediate_results_sample_preference.append({
|
| 209 |
+
'round': round_num + 1,
|
| 210 |
+
'image_paths': images, # 所有生成的图片路径
|
| 211 |
+
'scores': score_list, # 每轮的得分列表
|
| 212 |
+
'max_image_path': images[score_list.index(max(score_list))], # 当前轮得分最高的图片路径
|
| 213 |
+
'max_score': max(score_list) # 当前轮得分最高的分数
|
| 214 |
+
})
|
| 215 |
+
|
| 216 |
+
# 更新图片选择
|
| 217 |
+
max_index = score_list.index(max(score_list))
|
| 218 |
+
image_path_chosen = images[max_index]
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# 最终结果保存
|
| 223 |
+
results.append({
|
| 224 |
+
'prompt': prompt,
|
| 225 |
+
'model_preference_image_chosen': image_path_chosen,
|
| 226 |
+
"model_preference_info": intermediate_results_model_preference, # 包含所有中间结果
|
| 227 |
+
'best_image_path': image_path_chosen,
|
| 228 |
+
'best_model': max_key,
|
| 229 |
+
'score': max(score_list),
|
| 230 |
+
'sample_preference_intermediate_results': intermediate_results_sample_preference, # 包含所有中间结果
|
| 231 |
+
})
|
| 232 |
+
with open(os.path.join(out_dir, 'result_json',f'{index}.json'),'w',encoding='utf-8') as f:
|
| 233 |
+
json.dump(results,f,ensure_ascii=False, indent=4)
|
| 234 |
+
|
| 235 |
+
return results
|
| 236 |
+
|
| 237 |
+
if __name__ == "__main__":
|
| 238 |
+
parser = argparse.ArgumentParser(description="Image Generation Script")
|
| 239 |
+
parser.add_argument('--prompt', type=str, required=True, help='The prompt for image generation')
|
| 240 |
+
parser.add_argument('--index', type=str, required=True, help='Index for saving results')
|
| 241 |
+
parser.add_argument('--device', type=str, default='cuda:1', help='Device to run the model on')
|
| 242 |
+
parser.add_argument('--reward_model', type=str, default='hpsv3', help='Reward model to use (hpsv2 or hpsv3 or pickscore or imagereward)')
|
| 243 |
+
|
| 244 |
+
args = parser.parse_args()
|
| 245 |
+
output_dir = f"cohp_output_{args.reward_model}"
|
| 246 |
+
|
| 247 |
+
os.makedirs(output_dir,exist_ok=True)
|
| 248 |
+
if args.reward_model == 'hpsv2':
|
| 249 |
+
|
| 250 |
+
inferencer = initialize_model(args.device, 'pretrained_models/HPS_v2.1_compressed.pt')
|
| 251 |
+
model_dict, tokenizer = inferencer
|
| 252 |
+
elif args.reward_model == 'hpsv3':
|
| 253 |
+
dtype = torch.bfloat16
|
| 254 |
+
inferencer = HPSv3RewardInferencer(device=args.device, dtype=dtype)
|
| 255 |
+
elif args.reward_model == 'imagereward':
|
| 256 |
+
inferencer = RM.load("ImageReward-v1.0").to(args.device)
|
| 257 |
+
elif args.reward_model == 'pickscore':
|
| 258 |
+
processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
|
| 259 |
+
model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"
|
| 260 |
+
processor_pickscore = AutoProcessor.from_pretrained(processor_name_or_path)
|
| 261 |
+
inferencer = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(args.device)
|
| 262 |
+
else:
|
| 263 |
+
raise ValueError("Unsupported reward model. Choose 'hpsv2', 'hpsv3', or 'imagereward'.")
|
| 264 |
+
pipeline_params = [
|
| 265 |
+
flux_dev_pipe,
|
| 266 |
+
kolors_pipe,
|
| 267 |
+
sd3_medium_pipe,
|
| 268 |
+
playground_v2_5_pipe
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
di_score_pipelines={}
|
| 272 |
+
di_pipeline = {
|
| 273 |
+
flux_dev_pipe:'flux',
|
| 274 |
+
kolors_pipe:'kolors',
|
| 275 |
+
sd3_medium_pipe:'sd3',
|
| 276 |
+
playground_v2_5_pipe:'playground_v2_5'
|
| 277 |
+
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
results = generate_images(
|
| 281 |
+
args.reward_model,
|
| 282 |
+
args.prompt,
|
| 283 |
+
args.index,
|
| 284 |
+
pipeline_params,
|
| 285 |
+
di_pipeline,
|
| 286 |
+
inferencer,
|
| 287 |
+
out_dir=output_dir,
|
| 288 |
+
num_rounds=4,
|
| 289 |
+
strength=0.8,
|
| 290 |
+
device=args.device)
|
hpsv3/cohp/generator.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import inspect
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from utils_cohp.utils import init_pipelines
|
| 7 |
+
|
| 8 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Generator:
|
| 12 |
+
def __init__(
|
| 13 |
+
self, pipe_name, pipe_type, pipe_init_kwargs, device=None
|
| 14 |
+
):
|
| 15 |
+
self.pipe_names = pipe_name
|
| 16 |
+
self.pipe_type = pipe_type
|
| 17 |
+
self.pipe_init_kwargs = pipe_init_kwargs
|
| 18 |
+
self.pipelines = init_pipelines(
|
| 19 |
+
pipe_name, pipe_init_kwargs, device
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
def generate_imgs(
|
| 23 |
+
self,
|
| 24 |
+
batch_size,
|
| 25 |
+
generation_path,
|
| 26 |
+
info_dict,
|
| 27 |
+
device,
|
| 28 |
+
weight_dtype,
|
| 29 |
+
seed,
|
| 30 |
+
generation_kwargs,
|
| 31 |
+
|
| 32 |
+
):
|
| 33 |
+
|
| 34 |
+
torch.cuda.set_device(device)
|
| 35 |
+
device = torch.device(device)
|
| 36 |
+
generator = torch.Generator().manual_seed(seed)
|
| 37 |
+
|
| 38 |
+
pipeline_signature = inspect.signature(self.pipelines)
|
| 39 |
+
pipeline_params = pipeline_signature.parameters.keys()
|
| 40 |
+
|
| 41 |
+
if 'height' not in pipeline_params:
|
| 42 |
+
generation_kwargs.pop('height', None)
|
| 43 |
+
print(f"Warning: Pipeline does not support 'height' parameter, removing from kwargs")
|
| 44 |
+
if 'width' not in pipeline_params:
|
| 45 |
+
generation_kwargs.pop('width', None)
|
| 46 |
+
print(f"Warning: Pipeline does not support 'width' parameter, removing from kwargs")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
outputs = self.pipelines(
|
| 50 |
+
prompt=info_dict['caption'], generator=generator,num_images_per_prompt = batch_size, **generation_kwargs
|
| 51 |
+
)
|
| 52 |
+
if self.pipe_type == "t2i":
|
| 53 |
+
images = outputs.images
|
| 54 |
+
elif self.pipe_type == "t2v":
|
| 55 |
+
images = outputs.frames[0]
|
| 56 |
+
image_paths = []
|
| 57 |
+
for idx, image in enumerate(images):
|
| 58 |
+
img_path = os.path.join(
|
| 59 |
+
generation_path, info_dict["save_name"] + f"_{idx}.png"
|
| 60 |
+
)
|
| 61 |
+
os.makedirs(generation_path,exist_ok=True)
|
| 62 |
+
image.save(img_path)
|
| 63 |
+
image_paths.append(img_path)
|
| 64 |
+
return image_paths
|
hpsv3/cohp/run_cohp.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import gc
|
| 5 |
+
import argparse
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from transformers import AutoProcessor, AutoModel
|
| 9 |
+
|
| 10 |
+
from generator import Generator
|
| 11 |
+
from hpsv3.inference import HPSv3RewardInferencer
|
| 12 |
+
from hpsv3.cohp.utils_cohp.pipelines import *
|
| 13 |
+
from hpsv3.cohp.utils_cohp.image2image_pipeline import Image2ImagePipeline
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
|
| 17 |
+
except:
|
| 18 |
+
print("HPSv2 model not found, skipping HPSv2 related imports.")
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import ImageReward as RM
|
| 22 |
+
except:
|
| 23 |
+
print("ImageReward module not found, skipping ImageReward related imports.")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def initialize_hpsv2_model(device, checkpoint_path):
|
| 27 |
+
model_dict = {}
|
| 28 |
+
model, _, preprocess_val = create_model_and_transforms(
|
| 29 |
+
'ViT-H-14',
|
| 30 |
+
'laion2B-s32B-b79K',
|
| 31 |
+
device=device,
|
| 32 |
+
precision='amp',
|
| 33 |
+
pretrained_image=False,
|
| 34 |
+
output_dict=True,
|
| 35 |
+
)
|
| 36 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 37 |
+
model.load_state_dict(checkpoint['state_dict'])
|
| 38 |
+
model = model.to(device).eval()
|
| 39 |
+
tokenizer = get_tokenizer('ViT-H-14')
|
| 40 |
+
|
| 41 |
+
model_dict['model'] = model
|
| 42 |
+
model_dict['preprocess_val'] = preprocess_val
|
| 43 |
+
return model_dict, tokenizer
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def score_hpsv2(model_dict, tokenizer, device, img_paths, prompts):
|
| 47 |
+
model = model_dict['model']
|
| 48 |
+
preprocess_val = model_dict['preprocess_val']
|
| 49 |
+
images = [preprocess_val(Image.open(p)).unsqueeze(0) for p in img_paths]
|
| 50 |
+
images = torch.cat(images, dim=0).to(device)
|
| 51 |
+
texts = tokenizer(prompts).to(device)
|
| 52 |
+
|
| 53 |
+
with torch.no_grad():
|
| 54 |
+
outputs = model(images, texts)
|
| 55 |
+
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
| 56 |
+
logits_per_image = image_features @ text_features.T
|
| 57 |
+
hps_scores = torch.diagonal(logits_per_image).cpu()
|
| 58 |
+
return hps_scores
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def calculate_pickscore_probs(model, processor, prompt, images, device):
|
| 62 |
+
image_inputs = processor(images=images, padding=True, return_tensors="pt").to(device)
|
| 63 |
+
text_inputs = processor(text=prompt, padding=True, return_tensors="pt").to(device)
|
| 64 |
+
|
| 65 |
+
with torch.no_grad():
|
| 66 |
+
image_embs = model.get_image_features(**image_inputs)
|
| 67 |
+
image_embs /= torch.norm(image_embs, dim=-1, keepdim=True)
|
| 68 |
+
|
| 69 |
+
text_embs = model.get_text_features(**text_inputs)
|
| 70 |
+
text_embs /= torch.norm(text_embs, dim=-1, keepdim=True)
|
| 71 |
+
|
| 72 |
+
scores = text_embs @ image_embs.T
|
| 73 |
+
return scores
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def generate_images(
|
| 77 |
+
reward_type, prompt, index, pipeline_params, pipelines_mapping, inferencer,
|
| 78 |
+
output_dir='cohp_output', num_rounds=5, strength=0.8, device='cuda:1'
|
| 79 |
+
):
|
| 80 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 81 |
+
result_json_dir = os.path.join(output_dir, 'result_json')
|
| 82 |
+
os.makedirs(result_json_dir, exist_ok=True)
|
| 83 |
+
|
| 84 |
+
info_dict = {
|
| 85 |
+
'caption': prompt,
|
| 86 |
+
'width': 1024,
|
| 87 |
+
'height': 1024,
|
| 88 |
+
'aspect_ratio': 1,
|
| 89 |
+
'save_name': f"{index}_origin",
|
| 90 |
+
}
|
| 91 |
+
di_score_pipelines = {}
|
| 92 |
+
intermediate_results_model_pref = {}
|
| 93 |
+
intermediate_results_sample_pref = {}
|
| 94 |
+
max_final_score = 0
|
| 95 |
+
|
| 96 |
+
for pipeline_param in pipeline_params:
|
| 97 |
+
generator = Generator(
|
| 98 |
+
device=device,
|
| 99 |
+
pipe_name=pipeline_param.pipeline_name,
|
| 100 |
+
pipe_type=pipeline_param.pipeline_type,
|
| 101 |
+
pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
|
| 102 |
+
)
|
| 103 |
+
image_paths = generator.generate_imgs(
|
| 104 |
+
info_dict=info_dict,
|
| 105 |
+
generation_path=os.path.join(output_dir, pipeline_param.generation_path),
|
| 106 |
+
batch_size=2,
|
| 107 |
+
device=device,
|
| 108 |
+
seed=random.randint(0, 75859066837),
|
| 109 |
+
weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
|
| 110 |
+
generation_kwargs=pipeline_param.generation_kwargs
|
| 111 |
+
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
score_list = []
|
| 115 |
+
for image_path in image_paths:
|
| 116 |
+
if reward_type == 'hpsv2':
|
| 117 |
+
score = score_hpsv2(model_dict, tokenizer, device, [image_path], [prompt]).item()
|
| 118 |
+
elif reward_type == 'hpsv3':
|
| 119 |
+
score = inferencer.reward([image_path], [prompt]).cpu().detach()[0][0].item()
|
| 120 |
+
elif reward_type == 'imagereward':
|
| 121 |
+
score = inferencer.score(prompt, [image_path])
|
| 122 |
+
elif reward_type == 'pickscore':
|
| 123 |
+
score = calculate_pickscore_probs(inferencer, processor_pickscore, prompt, [Image.open(image_path)], device)[0][0].item()
|
| 124 |
+
else:
|
| 125 |
+
raise ValueError(f"Unsupported reward type: {reward_type}")
|
| 126 |
+
score_list.append(score)
|
| 127 |
+
|
| 128 |
+
average_score = sum(score_list) / len(score_list)
|
| 129 |
+
pipeline_name = pipelines_mapping[pipeline_param]
|
| 130 |
+
di_score_pipelines[pipeline_name] = average_score
|
| 131 |
+
|
| 132 |
+
intermediate_results_model_pref[pipeline_name] = {
|
| 133 |
+
'image_paths': image_paths,
|
| 134 |
+
'scores': score_list,
|
| 135 |
+
'max_image_path': image_paths[score_list.index(max(score_list))],
|
| 136 |
+
'max_score': max(score_list),
|
| 137 |
+
}
|
| 138 |
+
generator.pipelines.to("cpu")
|
| 139 |
+
del generator
|
| 140 |
+
torch.cuda.empty_cache()
|
| 141 |
+
gc.collect()
|
| 142 |
+
|
| 143 |
+
# Select the best pipeline based on scores
|
| 144 |
+
best_pipeline = max(di_score_pipelines, key=di_score_pipelines.get)
|
| 145 |
+
best_pipeline_results = intermediate_results_model_pref[best_pipeline]
|
| 146 |
+
chosen_image_path = best_pipeline_results['max_image_path']
|
| 147 |
+
|
| 148 |
+
# Refinement with Image2ImagePipeline
|
| 149 |
+
i2ipipeline = Image2ImagePipeline(best_pipeline)
|
| 150 |
+
for round_num in range(num_rounds):
|
| 151 |
+
if round_num in [3, 4]:
|
| 152 |
+
strength = 0.5
|
| 153 |
+
images = i2ipipeline.generate_image(
|
| 154 |
+
prompt=prompt,
|
| 155 |
+
image_path=chosen_image_path,
|
| 156 |
+
strength=strength,
|
| 157 |
+
batch_size=4,
|
| 158 |
+
save_prefix=f'{index}_{best_pipeline}_image2image_round{round_num + 1}',
|
| 159 |
+
output_dir=output_dir,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
score_list = []
|
| 163 |
+
for image_path in images:
|
| 164 |
+
if reward_type == 'hpsv2':
|
| 165 |
+
score = score_hpsv2(model_dict, tokenizer, device, [image_path], [prompt]).item()
|
| 166 |
+
elif reward_type == 'hpsv3':
|
| 167 |
+
score = inferencer.reward([image_path], [prompt]).cpu().detach()[0][0].item()
|
| 168 |
+
elif reward_type == 'imagereward':
|
| 169 |
+
score = inferencer.score(prompt, [image_path])
|
| 170 |
+
elif reward_type == 'pickscore':
|
| 171 |
+
score = calculate_pickscore_probs(inferencer, processor_pickscore, prompt, [Image.open(image_path)], device)[0][0].item()
|
| 172 |
+
else:
|
| 173 |
+
raise ValueError(f"Unsupported reward type: {reward_type}")
|
| 174 |
+
score_list.append(score)
|
| 175 |
+
|
| 176 |
+
# Update intermediate results
|
| 177 |
+
intermediate_results_sample_pref[round_num + 1] = {
|
| 178 |
+
'image_paths': images,
|
| 179 |
+
'scores': score_list,
|
| 180 |
+
'max_image_path': images[score_list.index(max(score_list))],
|
| 181 |
+
'max_score': max(score_list),
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
# Determine best image during refinement
|
| 185 |
+
if max(score_list) > max_final_score:
|
| 186 |
+
max_final_score = max(score_list)
|
| 187 |
+
chosen_image_path = images[score_list.index(max(score_list))]
|
| 188 |
+
|
| 189 |
+
# Save final results
|
| 190 |
+
results = {
|
| 191 |
+
'prompt': prompt,
|
| 192 |
+
'best_model': best_pipeline,
|
| 193 |
+
'final_image_path': chosen_image_path,
|
| 194 |
+
'model_preference_info': intermediate_results_model_pref,
|
| 195 |
+
'sample_preference_intermediate_results': intermediate_results_sample_pref,
|
| 196 |
+
}
|
| 197 |
+
with open(os.path.join(result_json_dir, f'{index}.json'), 'w', encoding='utf-8') as file:
|
| 198 |
+
json.dump(results, file, ensure_ascii=False, indent=4)
|
| 199 |
+
return results
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
|
| 203 |
+
parser = argparse.ArgumentParser(description="Image Generation Script")
|
| 204 |
+
parser.add_argument('--prompt', type=str, required=True, help='The prompt for image generation')
|
| 205 |
+
parser.add_argument('--index', type=str, required=True, help='Index for saving results')
|
| 206 |
+
parser.add_argument('--device', type=str, default='cuda:1', help='Device to run the model on')
|
| 207 |
+
parser.add_argument('--reward_model', type=str, default='hpsv3', help='Reward model to use (hpsv2, hpsv3, pickscore, or imagereward)')
|
| 208 |
+
args = parser.parse_args()
|
| 209 |
+
|
| 210 |
+
# Initialize models and pipelines
|
| 211 |
+
output_dir = f"cohp_output_{args.reward_model}"
|
| 212 |
+
if args.reward_model == 'hpsv2':
|
| 213 |
+
model_dict, tokenizer = initialize_hpsv2_model(args.device, 'pretrained_models/HPS_v2.1_compressed.pt')
|
| 214 |
+
inferencer = model_dict
|
| 215 |
+
elif args.reward_model == 'hpsv3':
|
| 216 |
+
inferencer = HPSv3RewardInferencer(device=args.device)
|
| 217 |
+
elif args.reward_model == 'imagereward':
|
| 218 |
+
inferencer = RM.load("ImageReward-v1.0").to(args.device)
|
| 219 |
+
elif args.reward_model == 'pickscore':
|
| 220 |
+
processor_pickscore = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
| 221 |
+
inferencer = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").eval().to(args.device)
|
| 222 |
+
else:
|
| 223 |
+
raise ValueError("Unsupported reward model.")
|
| 224 |
+
|
| 225 |
+
# Define pipelines
|
| 226 |
+
pipeline_params = [kolors_pipe, sd3_medium_pipe, playground_v2_5_pipe, flux_dev_pipe]
|
| 227 |
+
pipelines_mapping = {
|
| 228 |
+
flux_dev_pipe: 'flux',
|
| 229 |
+
kolors_pipe: 'kolors',
|
| 230 |
+
sd3_medium_pipe: 'sd3',
|
| 231 |
+
playground_v2_5_pipe: 'playground_v2_5',
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
# Generate images
|
| 235 |
+
results = generate_images(
|
| 236 |
+
reward_type=args.reward_model,
|
| 237 |
+
prompt=args.prompt,
|
| 238 |
+
index=args.index,
|
| 239 |
+
pipeline_params=pipeline_params,
|
| 240 |
+
pipelines_mapping=pipelines_mapping,
|
| 241 |
+
inferencer=inferencer,
|
| 242 |
+
output_dir=output_dir,
|
| 243 |
+
num_rounds=4,
|
| 244 |
+
)
|
hpsv3/cohp/utils_cohp/__init__.py
ADDED
|
File without changes
|
hpsv3/cohp/utils_cohp/image2image_pipeline.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from diffusers import FluxImg2ImgPipeline, KolorsImg2ImgPipeline, StableDiffusion3Img2ImgPipeline, StableDiffusionXLImg2ImgPipeline
|
| 2 |
+
from diffusers.utils import load_image
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
class Image2ImagePipeline:
|
| 6 |
+
def __init__(
|
| 7 |
+
self, pipe_name, device='cuda'
|
| 8 |
+
):
|
| 9 |
+
self.pipe_name = pipe_name
|
| 10 |
+
if self.pipe_name == 'flux':
|
| 11 |
+
self.pipeline = FluxImg2ImgPipeline.from_pretrained("pretrained_models/FLUX.1-dev",torch_dtype=torch.bfloat16).to(device)
|
| 12 |
+
self.generation_path = 'generation/flux_dev',
|
| 13 |
+
elif self.pipe_name == 'kolors':
|
| 14 |
+
self.pipeline = KolorsImg2ImgPipeline.from_pretrained("/preflab/shuiyunhao/tasks/HPSv3/pretrained_models/kolors",torch_dtype=torch.bfloat16).to(device)
|
| 15 |
+
self.generation_path = 'generation/kolors',
|
| 16 |
+
|
| 17 |
+
elif self.pipe_name == 'sd3':
|
| 18 |
+
self.pipeline = StableDiffusion3Img2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium",torch_dtype=torch.bfloat16).to(device)
|
| 19 |
+
self.generation_path = 'generation/sd3_medium',
|
| 20 |
+
elif self.pipe_name == 'playground_v2_5':
|
| 21 |
+
self.pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained("pretrained_models/playground-v2.5-1024px-aesthetic",torch_dtype=torch.bfloat16).to(device)
|
| 22 |
+
self.generation_path = 'generation/playground_v_2_5',
|
| 23 |
+
self.pipeline = self.pipeline.to(torch.bfloat16)
|
| 24 |
+
def generate_image(
|
| 25 |
+
self,
|
| 26 |
+
prompt,
|
| 27 |
+
image_path,
|
| 28 |
+
strength,
|
| 29 |
+
batch_size,
|
| 30 |
+
save_prefix,
|
| 31 |
+
output_dir
|
| 32 |
+
):
|
| 33 |
+
image_load = load_image(image_path)
|
| 34 |
+
if self.pipe_name == 'flux':
|
| 35 |
+
images = self.pipeline(
|
| 36 |
+
prompt = prompt,
|
| 37 |
+
image=image_load,
|
| 38 |
+
num_images_per_prompt=batch_size,
|
| 39 |
+
strength = strength).images
|
| 40 |
+
else:
|
| 41 |
+
|
| 42 |
+
images = self.pipeline(
|
| 43 |
+
prompt = prompt,
|
| 44 |
+
negative_prompt = '',
|
| 45 |
+
image=image_load,
|
| 46 |
+
num_images_per_prompt=batch_size,
|
| 47 |
+
strength = strength).images
|
| 48 |
+
image_list = []
|
| 49 |
+
for ind,img in enumerate(images):
|
| 50 |
+
print(output_dir,self.generation_path,save_prefix)
|
| 51 |
+
save_path = os.path.join(output_dir,self.generation_path[0],save_prefix+f'_{ind}.png')
|
| 52 |
+
image_list.append(save_path)
|
| 53 |
+
img.save(save_path)
|
| 54 |
+
print(image_list)
|
| 55 |
+
return image_list
|
| 56 |
+
|
| 57 |
+
# pipeline = StableDiffusion3Img2ImgPipeline.from_pretrained("/preflab/shuiyunhao/tasks/HPSv3/pretrained_models/stable-diffusion-3-medium-diffusers",torch_dtype=torch.bfloat16).to('cuda:0')
|
| 58 |
+
# pipeline = pipeline.to(torch.bfloat16)
|
| 59 |
+
# image_load = load_image('/preflab/shuiyunhao/tasks/HPSv3/cohp_output/generation/flux_dev/0_origin_0.png')
|
| 60 |
+
# images = pipeline(
|
| 61 |
+
# prompt = 'a girl',
|
| 62 |
+
# negative_prompt = '',
|
| 63 |
+
# image=image_load,
|
| 64 |
+
# num_images_per_prompt=1,
|
| 65 |
+
# strength = 0.8).images
|
hpsv3/cohp/utils_cohp/pipelines.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
class PipelineParam:
|
| 3 |
+
pipeline_name: str
|
| 4 |
+
pipeline_type: str
|
| 5 |
+
generation_path: str
|
| 6 |
+
pipe_init_kwargs: dict
|
| 7 |
+
generation_kwargs: dict
|
| 8 |
+
base_resolution: int
|
| 9 |
+
force_aspect_ratio: int
|
| 10 |
+
|
| 11 |
+
def __init__(self, pipeline_name: str, generation_path: str, pipeline_type = 't2i',
|
| 12 |
+
pipe_init_kwargs: dict = None, generation_kwargs: dict = None,
|
| 13 |
+
base_resolution: int = 1024, force_aspect_ratio: int = None):
|
| 14 |
+
self.pipeline_name = pipeline_name
|
| 15 |
+
self.pipeline_type = pipeline_type
|
| 16 |
+
self.generation_path = generation_path
|
| 17 |
+
self.pipe_init_kwargs = pipe_init_kwargs if pipe_init_kwargs is not None else {}
|
| 18 |
+
self.generation_kwargs = generation_kwargs if generation_kwargs is not None else {}
|
| 19 |
+
self.base_resolution = base_resolution
|
| 20 |
+
self.force_aspect_ratio = force_aspect_ratio
|
| 21 |
+
|
| 22 |
+
flux_dev_pipe = PipelineParam(
|
| 23 |
+
pipeline_name='pretrained_models/FLUX.1-dev',
|
| 24 |
+
generation_path=f'generation/flux_dev',
|
| 25 |
+
pipe_init_kwargs={
|
| 26 |
+
"torch_dtype": torch.bfloat16,
|
| 27 |
+
},
|
| 28 |
+
base_resolution=1024,
|
| 29 |
+
generation_kwargs={
|
| 30 |
+
"guidance_scale": 3.5,
|
| 31 |
+
"num_inference_steps": 28,
|
| 32 |
+
"max_sequence_length": 512,
|
| 33 |
+
}
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
flux_schnell_pipe = PipelineParam(
|
| 37 |
+
pipeline_name='/mnt2/share/huggingface_models/FLUX.1-schnell',
|
| 38 |
+
generation_path=f'generation/flux_schnell',
|
| 39 |
+
pipe_init_kwargs={
|
| 40 |
+
"torch_dtype": torch.bfloat16,
|
| 41 |
+
},
|
| 42 |
+
base_resolution=1024,
|
| 43 |
+
generation_kwargs={
|
| 44 |
+
"guidance_scale": 3.5,
|
| 45 |
+
"num_inference_steps": 4,
|
| 46 |
+
}
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
sd3_medium_pipe = PipelineParam(
|
| 51 |
+
pipeline_name='pretrained_models/stable-diffusion-3-medium-diffusers',
|
| 52 |
+
generation_path=f'generation/sd3_medium',
|
| 53 |
+
pipe_init_kwargs={
|
| 54 |
+
"torch_dtype": torch.float16,
|
| 55 |
+
},
|
| 56 |
+
base_resolution=1024,
|
| 57 |
+
generation_kwargs={
|
| 58 |
+
"guidance_scale": 7.0,
|
| 59 |
+
"num_inference_steps": 28,
|
| 60 |
+
}
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
sd_xl_pipe = PipelineParam(
|
| 64 |
+
pipeline_name='pretrained_models/stable-diffusion-xl-base-1.0',
|
| 65 |
+
generation_path=f'generation/sd_xl',
|
| 66 |
+
pipe_init_kwargs={
|
| 67 |
+
"torch_dtype": torch.float16,
|
| 68 |
+
},
|
| 69 |
+
base_resolution=1024,
|
| 70 |
+
generation_kwargs={
|
| 71 |
+
"guidance_scale": 5,
|
| 72 |
+
"num_inference_steps": 50,
|
| 73 |
+
}
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
sd_1_5_pipe = PipelineParam(
|
| 77 |
+
pipeline_name='pretrained_models/stable-diffusion-v1-5',
|
| 78 |
+
generation_path=f'generation/sd_1_5',
|
| 79 |
+
pipe_init_kwargs={
|
| 80 |
+
"torch_dtype": torch.float16,
|
| 81 |
+
},
|
| 82 |
+
base_resolution=512,
|
| 83 |
+
generation_kwargs={
|
| 84 |
+
}
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
vq_diffusion_pipe = PipelineParam(
|
| 88 |
+
pipeline_name='pretrained_models/vq-diffusion-ithq',
|
| 89 |
+
generation_path=f'generation/vq_diffusion',
|
| 90 |
+
pipe_init_kwargs={
|
| 91 |
+
"torch_dtype": torch.float16,
|
| 92 |
+
},
|
| 93 |
+
base_resolution=256,
|
| 94 |
+
generation_kwargs={}
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
sd_2_pipe = PipelineParam(
|
| 98 |
+
pipeline_name='pretrained_models/stable-diffusion-2',
|
| 99 |
+
generation_path=f'generation/sd_2',
|
| 100 |
+
pipe_init_kwargs={
|
| 101 |
+
"torch_dtype": torch.float16,
|
| 102 |
+
},
|
| 103 |
+
base_resolution=512,
|
| 104 |
+
force_aspect_ratio=1,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
sd_1_1_pipe = PipelineParam(
|
| 108 |
+
pipeline_name='pretrained_models/stable-diffusion-v1-1',
|
| 109 |
+
generation_path=f'generation/sd_1_1',
|
| 110 |
+
pipe_init_kwargs={"torch_dtype": torch.float16,},
|
| 111 |
+
base_resolution=512,
|
| 112 |
+
force_aspect_ratio=1,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
sd_1_4_pipe = PipelineParam(
|
| 116 |
+
pipeline_name='pretrained_models/stable-diffusion-v1-4',
|
| 117 |
+
generation_path=f'generation/sd_1_4',
|
| 118 |
+
pipe_init_kwargs={
|
| 119 |
+
"torch_dtype": torch.float16,
|
| 120 |
+
},
|
| 121 |
+
base_resolution=512,
|
| 122 |
+
force_aspect_ratio=1,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
sd_2_1_pipe = PipelineParam(
|
| 126 |
+
pipeline_name='pretrained_models/stable-diffusion-2-1-base',
|
| 127 |
+
generation_path=f'generation/sd_2_1',
|
| 128 |
+
pipe_init_kwargs={
|
| 129 |
+
"torch_dtype": torch.float16,
|
| 130 |
+
},
|
| 131 |
+
base_resolution=512,
|
| 132 |
+
force_aspect_ratio=1,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
openjourney_pipe = PipelineParam(
|
| 136 |
+
pipeline_name='pretrained_models/openjourney',
|
| 137 |
+
generation_path=f'generation/openjourney',
|
| 138 |
+
pipe_init_kwargs={
|
| 139 |
+
"torch_dtype": torch.float16,
|
| 140 |
+
},
|
| 141 |
+
base_resolution=512,
|
| 142 |
+
force_aspect_ratio=1,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
playground_v2_5_pipe = PipelineParam(
|
| 146 |
+
pipeline_name='pretrained_models/playground-v2.5-1024px-aesthetic',
|
| 147 |
+
generation_path=f'generation/playground_v_2_5',
|
| 148 |
+
pipe_init_kwargs={
|
| 149 |
+
"torch_dtype": torch.float16,
|
| 150 |
+
},
|
| 151 |
+
base_resolution=1024,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
versatile_pipe = PipelineParam(
|
| 155 |
+
pipeline_name='pretrained_models/versatile-diffusion',
|
| 156 |
+
generation_path=f'generation/versatile',
|
| 157 |
+
pipe_init_kwargs={
|
| 158 |
+
"torch_dtype": torch.float16,
|
| 159 |
+
},
|
| 160 |
+
base_resolution=512,
|
| 161 |
+
force_aspect_ratio=1,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
glide_pipe = PipelineParam(
|
| 165 |
+
pipeline_name='pretrained_models/glide-base',
|
| 166 |
+
generation_path=f'generation/glide',
|
| 167 |
+
pipe_init_kwargs={
|
| 168 |
+
"torch_dtype": torch.float16,
|
| 169 |
+
},
|
| 170 |
+
base_resolution=512,
|
| 171 |
+
force_aspect_ratio=1,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
sd_3_5_medium_pipe = PipelineParam(
|
| 175 |
+
pipeline_name='stabilityai/stable-diffusion-3.5-medium',
|
| 176 |
+
generation_path=f'generation/sd_3_5_medium',
|
| 177 |
+
pipe_init_kwargs={
|
| 178 |
+
"torch_dtype": torch.bfloat16,
|
| 179 |
+
},
|
| 180 |
+
base_resolution=1024,
|
| 181 |
+
generation_kwargs={
|
| 182 |
+
"num_inference_steps": 40,
|
| 183 |
+
"guidance_scale": 4.5,
|
| 184 |
+
}
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
sd_3_5_large_pipe = PipelineParam(
|
| 188 |
+
pipeline_name='stabilityai/stable-diffusion-3.5-large',
|
| 189 |
+
generation_path=f'generation/sd_3_5_large',
|
| 190 |
+
pipe_init_kwargs={
|
| 191 |
+
"torch_dtype": torch.bfloat16,
|
| 192 |
+
},
|
| 193 |
+
base_resolution=1024,
|
| 194 |
+
generation_kwargs={
|
| 195 |
+
"num_inference_steps": 28,
|
| 196 |
+
"guidance_scale": 3.5,
|
| 197 |
+
}
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
kolors_pipe = PipelineParam(
|
| 201 |
+
pipeline_name='pretrained_models/Kolors-diffusers',
|
| 202 |
+
generation_path=f'generation/kolors',
|
| 203 |
+
pipe_init_kwargs={
|
| 204 |
+
"torch_dtype": torch.float16,
|
| 205 |
+
'variant': 'fp16',
|
| 206 |
+
},
|
| 207 |
+
base_resolution=1024,
|
| 208 |
+
generation_kwargs={
|
| 209 |
+
"num_inference_steps": 50,
|
| 210 |
+
"guidance_scale": 5.0,
|
| 211 |
+
}
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
cogview4_pipe = PipelineParam(
|
| 215 |
+
pipeline_name='pretrained_models/CogView4-6B',
|
| 216 |
+
generation_path=f'generation/cogview4',
|
| 217 |
+
pipe_init_kwargs={
|
| 218 |
+
"torch_dtype": torch.bfloat16,
|
| 219 |
+
},
|
| 220 |
+
base_resolution=1024,
|
| 221 |
+
generation_kwargs={
|
| 222 |
+
"num_inference_steps": 50,
|
| 223 |
+
"guidance_scale": 3.5,
|
| 224 |
+
}
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
pixart_sigma_pipe = PipelineParam(
|
| 228 |
+
pipeline_name='pretrained_models/PixArt-Sigma-XL-2-1024-MS',
|
| 229 |
+
generation_path=f'generation/pixart_sigma',
|
| 230 |
+
pipeline_type='t2i',
|
| 231 |
+
pipe_init_kwargs={
|
| 232 |
+
"torch_dtype": torch.bfloat16,
|
| 233 |
+
},
|
| 234 |
+
base_resolution=1024,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
hunyuanvideo_pipe = PipelineParam(
|
| 238 |
+
pipeline_name='pretrained_models/hunyuanvideo_diffusers',
|
| 239 |
+
generation_path=f'generation/hunyuanvideo',
|
| 240 |
+
pipe_init_kwargs={
|
| 241 |
+
"torch_dtype": torch.bfloat16,
|
| 242 |
+
},
|
| 243 |
+
base_resolution=1024,
|
| 244 |
+
pipeline_type='t2v',
|
| 245 |
+
generation_kwargs={
|
| 246 |
+
"num_inference_steps": 30,
|
| 247 |
+
"num_frames": 1,
|
| 248 |
+
}
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
hunyuandit_pipe = PipelineParam(
|
| 252 |
+
pipeline_name='pretrained_models/HunyuanDiT-v1.2-Diffusers',
|
| 253 |
+
generation_path=f'generation/hunyuandit',
|
| 254 |
+
pipe_init_kwargs={
|
| 255 |
+
"torch_dtype": torch.float16,
|
| 256 |
+
},
|
| 257 |
+
base_resolution=1024,
|
| 258 |
+
pipeline_type='t2i',
|
| 259 |
+
generation_kwargs={
|
| 260 |
+
}
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# API models
|
| 264 |
+
# Fal.ai
|
| 265 |
+
flux_pro_v1_1_ultr_pipe = PipelineParam(
|
| 266 |
+
pipeline_name='fal-ai/flux-pro/v1.1-ultra',
|
| 267 |
+
generation_path=f'generation/flux_pro_v1_1_ultra',
|
| 268 |
+
base_resolution=1024,
|
| 269 |
+
generation_kwargs={
|
| 270 |
+
"enable_safety_checker": False,
|
| 271 |
+
"num_images": 1,
|
| 272 |
+
# "aspect_ratio": "1:1",
|
| 273 |
+
"output_format": "jpeg",
|
| 274 |
+
"safety_tolerance": 5,
|
| 275 |
+
}
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
recraftv3_pipe = PipelineParam(
|
| 279 |
+
pipeline_name='fal-ai/recraft-v3',
|
| 280 |
+
generation_path=f'generation/recraftv3',
|
| 281 |
+
base_resolution=1024,
|
| 282 |
+
generation_kwargs={
|
| 283 |
+
"enable_safety_checker": False,
|
| 284 |
+
"num_images": 1,
|
| 285 |
+
# "aspect_ratio": "1:1",
|
| 286 |
+
"output_format": "jpeg",
|
| 287 |
+
"safety_tolerance": 5,
|
| 288 |
+
}
|
| 289 |
+
)
|
| 290 |
+
|
hpsv3/cohp/utils_cohp/utils.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
try:
|
| 3 |
+
import fal_client
|
| 4 |
+
except:
|
| 5 |
+
fal_client = None
|
| 6 |
+
try:
|
| 7 |
+
from diffusers import AutoPipelineForText2Image, DiffusionPipeline
|
| 8 |
+
except:
|
| 9 |
+
AutoPipelineForText2Image = None
|
| 10 |
+
DiffusionPipeline = None
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
import diffusers
|
| 14 |
+
from functools import partial
|
| 15 |
+
import os
|
| 16 |
+
# export FAL_KEY="YOUR_API_KEY"
|
| 17 |
+
os.environ['FAL_KEY'] = 'YOUR_API_KEY'
|
| 18 |
+
|
| 19 |
+
def init_pipelines(pipe_name, pipe_init_kwargs, device=None):
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(pipe_name, **pipe_init_kwargs).to(device)
|
| 23 |
+
except Exception as e:
|
| 24 |
+
# try:
|
| 25 |
+
config = json.load(open(os.path.join(pipe_name, 'model_index.json')))
|
| 26 |
+
class_name_str = config['_class_name']
|
| 27 |
+
pipeline_class = getattr(diffusers, class_name_str)
|
| 28 |
+
pipeline = pipeline_class.from_pretrained(pipe_name, **pipe_init_kwargs).to(device)
|
| 29 |
+
|
| 30 |
+
return pipeline
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def init_pipeline_from_names(pipe_names, weight_dtype):
|
| 34 |
+
pipelines_dict = {}
|
| 35 |
+
for name in pipe_names:
|
| 36 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(name, torch_dtype=weight_dtype)
|
| 37 |
+
pipelines_dict[name] = pipeline
|
| 38 |
+
return pipelines_dict
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def on_queue_update(update):
|
| 42 |
+
if isinstance(update, fal_client.InProgress):
|
| 43 |
+
for log in update.logs:
|
| 44 |
+
print(log["message"])
|
| 45 |
+
|
| 46 |
+
def gen_with_api(pipe_names, generation_kwargs):
|
| 47 |
+
result = fal_client.subscribe(
|
| 48 |
+
pipe_names,
|
| 49 |
+
arguments=generation_kwargs,
|
| 50 |
+
with_logs=True,
|
| 51 |
+
on_queue_update=on_queue_update,
|
| 52 |
+
)
|
| 53 |
+
return result
|
hpsv3/config/HPSv3_7B.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Configuration
|
| 2 |
+
rm_head_type: "ranknet"
|
| 3 |
+
lora_enable: False
|
| 4 |
+
vision_lora: False
|
| 5 |
+
freeze_vision_tower: False
|
| 6 |
+
freeze_llm: False
|
| 7 |
+
tune_merger: True
|
| 8 |
+
model_name_or_path: "Qwen/Qwen2-VL-7B-Instruct"
|
| 9 |
+
num_lora_modules: -1
|
| 10 |
+
lora_r: 512
|
| 11 |
+
lora_alpha: 1024
|
| 12 |
+
lora_namespan_exclude: ['lm_head', 'rm_head', 'embed_tokens']
|
| 13 |
+
|
| 14 |
+
# Data Configuration
|
| 15 |
+
confidence_threshold: 0.95
|
| 16 |
+
tied_threshold: null
|
| 17 |
+
max_pixels: 200704 # 256 * 28 * 28
|
| 18 |
+
min_pixels: 200704
|
| 19 |
+
with_instruction: true
|
| 20 |
+
|
| 21 |
+
train_json_list:
|
| 22 |
+
- example_train.json
|
| 23 |
+
test_json_list:
|
| 24 |
+
- ["Valid Set 1", ["example_set_1_part1.json", "example_set_1_part2.json"]]
|
| 25 |
+
- ['Valid Set 2',["example_set_2_part1.json"]]
|
| 26 |
+
|
| 27 |
+
soft_label: False
|
| 28 |
+
output_dir: output_models
|
| 29 |
+
use_special_tokens: true
|
| 30 |
+
reward_token: "special"
|
| 31 |
+
output_dim: 2
|
| 32 |
+
loss_type: "uncertainty"
|
| 33 |
+
|
| 34 |
+
# Training Configuration
|
| 35 |
+
disable_flash_attn2: False
|
| 36 |
+
per_device_train_batch_size: 2
|
| 37 |
+
per_device_eval_batch_size: 8
|
| 38 |
+
gradient_accumulation_steps: 4
|
| 39 |
+
num_train_epochs: 10
|
| 40 |
+
learning_rate: 2.0e-6
|
| 41 |
+
special_token_lr: 2.0e-6
|
| 42 |
+
warmup_ratio: 0.05
|
| 43 |
+
lr_scheduler_type: "constant_with_warmup"
|
| 44 |
+
gradient_checkpointing: True
|
| 45 |
+
gradient_checkpointing_kwargs: {"use_reentrant": False}
|
| 46 |
+
|
| 47 |
+
# Evaluation and Logging
|
| 48 |
+
eval_strategy: "steps"
|
| 49 |
+
logging_epochs: 0.01
|
| 50 |
+
eval_epochs: 0.1
|
| 51 |
+
save_epochs: 0.1
|
| 52 |
+
report_to: tensorboard
|
| 53 |
+
|
| 54 |
+
# System Configuration
|
| 55 |
+
bf16: True
|
| 56 |
+
torch_dtype: "bfloat16"
|
| 57 |
+
deepspeed: hpsv3/config/ds_config/zero2.json
|
| 58 |
+
save_only_model: True
|
| 59 |
+
save_full_model: True
|
| 60 |
+
dataloader_num_workers: 8
|
hpsv3/config/ds_config/zero0.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 14 |
+
"train_batch_size": "auto",
|
| 15 |
+
"gradient_accumulation_steps": "auto",
|
| 16 |
+
"zero_optimization": {
|
| 17 |
+
"stage": 0
|
| 18 |
+
}
|
| 19 |
+
}
|
hpsv3/config/ds_config/zero2.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 14 |
+
"train_batch_size": "auto",
|
| 15 |
+
"gradient_accumulation_steps": "auto",
|
| 16 |
+
"zero_optimization": {
|
| 17 |
+
"stage": 2,
|
| 18 |
+
"overlap_comm": true,
|
| 19 |
+
"contiguous_gradients": true,
|
| 20 |
+
"sub_group_size": 1e9,
|
| 21 |
+
"reduce_bucket_size": "auto"
|
| 22 |
+
}
|
| 23 |
+
}
|
hpsv3/config/ds_config/zero3.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 14 |
+
"train_batch_size": "auto",
|
| 15 |
+
"gradient_accumulation_steps": "auto",
|
| 16 |
+
"zero_optimization": {
|
| 17 |
+
"stage": 3,
|
| 18 |
+
"overlap_comm": true,
|
| 19 |
+
"contiguous_gradients": true,
|
| 20 |
+
"sub_group_size": 1e9,
|
| 21 |
+
"reduce_bucket_size": "auto",
|
| 22 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 23 |
+
"stage3_param_persistence_threshold": "auto",
|
| 24 |
+
"stage3_max_live_parameters": 1e9,
|
| 25 |
+
"stage3_max_reuse_distance": 1e9,
|
| 26 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 27 |
+
}
|
| 28 |
+
}
|
hpsv3/dataset/data_collator_qwen.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pdb
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Optional, List, Union
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from hpsv3.dataset.utils import process_vision_info
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
import torchvision.transforms.functional as F
|
| 10 |
+
|
| 11 |
+
INSTRUCTION = """
|
| 12 |
+
You are tasked with evaluating a generated image based on Visual Quality and Text Alignment and give a overall score to estimate the human preference. Please provide a rating from 0 to 10, with 0 being the worst and 10 being the best.
|
| 13 |
+
|
| 14 |
+
**Visual Quality:**
|
| 15 |
+
Evaluate the overall visual quality of the image. The following sub-dimensions should be considered:
|
| 16 |
+
- **Reasonableness:** The image should not contain any significant biological or logical errors, such as abnormal body structures or nonsensical environmental setups.
|
| 17 |
+
- **Clarity:** Evaluate the sharpness and visibility of the image. The image should be clear and easy to interpret, with no blurring or indistinct areas.
|
| 18 |
+
- **Detail Richness:** Consider the level of detail in textures, materials, lighting, and other visual elements (e.g., hair, clothing, shadows).
|
| 19 |
+
- **Aesthetic and Creativity:** Assess the artistic aspects of the image, including the color scheme, composition, atmosphere, depth of field, and the overall creative appeal. The scene should convey a sense of harmony and balance.
|
| 20 |
+
- **Safety:** The image should not contain harmful or inappropriate content, such as political, violent, or adult material. If such content is present, the image quality and satisfaction score should be the lowest possible.
|
| 21 |
+
|
| 22 |
+
**Text Alignment:**
|
| 23 |
+
Assess how well the image matches the textual prompt across the following sub-dimensions:
|
| 24 |
+
- **Subject Relevance** Evaluate how accurately the subject(s) in the image (e.g., person, animal, object) align with the textual description. The subject should match the description in terms of number, appearance, and behavior.
|
| 25 |
+
- **Style Relevance:** If the prompt specifies a particular artistic or stylistic style, evaluate how well the image adheres to this style.
|
| 26 |
+
- **Contextual Consistency**: Assess whether the background, setting, and surrounding elements in the image logically fit the scenario described in the prompt. The environment should support and enhance the subject without contradictions.
|
| 27 |
+
- **Attribute Fidelity**: Check if specific attributes mentioned in the prompt (e.g., colors, clothing, accessories, expressions, actions) are faithfully represented in the image. Minor deviations may be acceptable, but critical attributes should be preserved.
|
| 28 |
+
- **Semantic Coherence**: Evaluate whether the overall meaning and intent of the prompt are captured in the image. The generated content should not introduce elements that conflict with or distort the original description.
|
| 29 |
+
Textual prompt - {text_prompt}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
INSTRUCTION_debug = """
|
| 35 |
+
{text_prompt}
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
prompt_with_special_token = """
|
| 39 |
+
Please provide the overall ratings of this image: <|Reward|>
|
| 40 |
+
|
| 41 |
+
END
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
prompt_without_special_token = """
|
| 45 |
+
Please provide the overall ratings of this image:
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class QWen2VLDataCollator:
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
processor,
|
| 53 |
+
with_instruction=True,
|
| 54 |
+
max_pixels=256 * 28 * 28, # Default max pixels
|
| 55 |
+
min_pixels=256 * 28 * 28, # Default min pixels
|
| 56 |
+
use_special_tokens=True,
|
| 57 |
+
):
|
| 58 |
+
self.processor = processor
|
| 59 |
+
self.with_instruction = with_instruction
|
| 60 |
+
self.max_pixels = max_pixels
|
| 61 |
+
self.min_pixels = min_pixels
|
| 62 |
+
self.use_special_tokens = use_special_tokens
|
| 63 |
+
|
| 64 |
+
def _clean_message(
|
| 65 |
+
self,
|
| 66 |
+
texts,
|
| 67 |
+
images,
|
| 68 |
+
max_pixels=256 * 28 * 28,
|
| 69 |
+
min_pixels=256 * 28 * 28,
|
| 70 |
+
with_instruction=True,
|
| 71 |
+
use_special_tokens=True,
|
| 72 |
+
):
|
| 73 |
+
"""
|
| 74 |
+
remove unnecessary keys from message(very very necessary)
|
| 75 |
+
"""
|
| 76 |
+
message_list = []
|
| 77 |
+
for text, image in zip(texts, images):
|
| 78 |
+
out_message = [
|
| 79 |
+
{
|
| 80 |
+
"role": "user",
|
| 81 |
+
"content": [
|
| 82 |
+
{
|
| 83 |
+
"type": "image",
|
| 84 |
+
"image": image,
|
| 85 |
+
"min_pixels": min_pixels,
|
| 86 |
+
"max_pixels": max_pixels,
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"type": "text",
|
| 90 |
+
"text": (
|
| 91 |
+
INSTRUCTION.format(text_prompt=text)
|
| 92 |
+
+ prompt_with_special_token
|
| 93 |
+
if use_special_tokens
|
| 94 |
+
else prompt_without_special_token
|
| 95 |
+
),
|
| 96 |
+
},
|
| 97 |
+
],
|
| 98 |
+
}
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
message_list.append(out_message)
|
| 102 |
+
|
| 103 |
+
return message_list
|
| 104 |
+
|
| 105 |
+
def _pad_sequence(self, sequences, attention_mask, max_len, padding_side="right"):
|
| 106 |
+
"""
|
| 107 |
+
Pad the sequences to the maximum length.
|
| 108 |
+
"""
|
| 109 |
+
assert padding_side in ["right", "left"]
|
| 110 |
+
if sequences.shape[1] >= max_len:
|
| 111 |
+
return sequences, attention_mask
|
| 112 |
+
|
| 113 |
+
pad_len = max_len - sequences.shape[1]
|
| 114 |
+
padding = (0, pad_len) if padding_side == "right" else (pad_len, 0)
|
| 115 |
+
|
| 116 |
+
sequences_padded = torch.nn.functional.pad(
|
| 117 |
+
sequences, padding, "constant", self.processor.tokenizer.pad_token_id
|
| 118 |
+
)
|
| 119 |
+
attention_mask_padded = torch.nn.functional.pad(
|
| 120 |
+
attention_mask, padding, "constant", 0
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
return sequences_padded, attention_mask_padded
|
| 124 |
+
|
| 125 |
+
def __call__(self, inputs, with_instruction=True):
|
| 126 |
+
"""
|
| 127 |
+
Preprocess inputs to token sequences and return a batch
|
| 128 |
+
"""
|
| 129 |
+
images_1, images_2, texts_1, texts_2 = [], [], [], []
|
| 130 |
+
|
| 131 |
+
for idx, batch in enumerate(inputs):
|
| 132 |
+
texts_1.append(batch["text_1"])
|
| 133 |
+
texts_2.append(batch["text_2"])
|
| 134 |
+
images_1.append(batch["image_1"])
|
| 135 |
+
images_2.append(batch["image_2"])
|
| 136 |
+
|
| 137 |
+
messages_batch_1 = self._clean_message(
|
| 138 |
+
texts_1,
|
| 139 |
+
images_1,
|
| 140 |
+
max_pixels=self.max_pixels,
|
| 141 |
+
min_pixels=self.min_pixels,
|
| 142 |
+
with_instruction=self.with_instruction,
|
| 143 |
+
use_special_tokens=self.use_special_tokens,
|
| 144 |
+
)
|
| 145 |
+
messages_batch_2 = self._clean_message(
|
| 146 |
+
texts_2,
|
| 147 |
+
images_2,
|
| 148 |
+
max_pixels=self.max_pixels,
|
| 149 |
+
min_pixels=self.min_pixels,
|
| 150 |
+
with_instruction=self.with_instruction,
|
| 151 |
+
use_special_tokens=self.use_special_tokens,
|
| 152 |
+
)
|
| 153 |
+
# import pdb; pdb.set_trace()
|
| 154 |
+
image_inputs_1, _ = process_vision_info(messages_batch_1)
|
| 155 |
+
image_inputs_2, _ = process_vision_info(messages_batch_2)
|
| 156 |
+
image_inputs_1 = [
|
| 157 |
+
np.array(image_inputs_1[i]) / 255.0 for i in range(len(image_inputs_1))
|
| 158 |
+
]
|
| 159 |
+
image_inputs_2 = [
|
| 160 |
+
np.array(image_inputs_2[i]) / 255.0 for i in range(len(image_inputs_2))
|
| 161 |
+
]
|
| 162 |
+
do_rescale = False
|
| 163 |
+
|
| 164 |
+
batch_1 = self.processor(
|
| 165 |
+
text=self.processor.apply_chat_template(
|
| 166 |
+
messages_batch_1, tokenize=False, add_generation_prompt=True
|
| 167 |
+
),
|
| 168 |
+
images=image_inputs_1,
|
| 169 |
+
videos=None,
|
| 170 |
+
padding=True,
|
| 171 |
+
return_tensors="pt",
|
| 172 |
+
images_kwargs={"do_rescale": do_rescale},
|
| 173 |
+
)
|
| 174 |
+
batch_2 = self.processor(
|
| 175 |
+
text=self.processor.apply_chat_template(
|
| 176 |
+
messages_batch_2, tokenize=False, add_generation_prompt=True
|
| 177 |
+
),
|
| 178 |
+
images=image_inputs_2,
|
| 179 |
+
videos=None,
|
| 180 |
+
padding=True,
|
| 181 |
+
return_tensors="pt",
|
| 182 |
+
images_kwargs={"do_rescale": do_rescale},
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# pdb.set_trace()
|
| 186 |
+
max_len = max(batch_1["input_ids"].shape[1], batch_2["input_ids"].shape[1])
|
| 187 |
+
batch_1["input_ids"], batch_1["attention_mask"] = self._pad_sequence(
|
| 188 |
+
batch_1["input_ids"], batch_1["attention_mask"], max_len, "right"
|
| 189 |
+
)
|
| 190 |
+
batch_2["input_ids"], batch_2["attention_mask"] = self._pad_sequence(
|
| 191 |
+
batch_2["input_ids"], batch_2["attention_mask"], max_len, "right"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
batch = {
|
| 195 |
+
"batch_1": batch_1,
|
| 196 |
+
"batch_2": batch_2,
|
| 197 |
+
"choice_dist": torch.stack([batch["choice_dist"] for batch in inputs]),
|
| 198 |
+
# Store original text prompts for visualization
|
| 199 |
+
"text_1": texts_1,
|
| 200 |
+
"text_2": texts_2,
|
| 201 |
+
"image_1": image_inputs_1,
|
| 202 |
+
"image_2": image_inputs_2,
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
return batch
|
hpsv3/dataset/pairwise_dataset.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import random
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
class PairwiseOriginalDataset(Dataset):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
json_list,
|
| 12 |
+
soft_label=False,
|
| 13 |
+
confidence_threshold=None,
|
| 14 |
+
):
|
| 15 |
+
self.samples = []
|
| 16 |
+
for json_file in json_list:
|
| 17 |
+
with open(json_file, "r") as f:
|
| 18 |
+
data = json.load(f)
|
| 19 |
+
self.samples.extend(data)
|
| 20 |
+
|
| 21 |
+
self.soft_label = soft_label
|
| 22 |
+
self.confidence_threshold = confidence_threshold
|
| 23 |
+
|
| 24 |
+
if confidence_threshold is not None:
|
| 25 |
+
new_samples = []
|
| 26 |
+
for sample in tqdm(
|
| 27 |
+
self.samples, desc="Filtering samples according to confidence threshold"
|
| 28 |
+
):
|
| 29 |
+
if sample.get("confidence", float("inf")) >= confidence_threshold:
|
| 30 |
+
new_samples.append(sample)
|
| 31 |
+
self.samples = new_samples
|
| 32 |
+
|
| 33 |
+
def __len__(self):
|
| 34 |
+
return len(self.samples)
|
| 35 |
+
|
| 36 |
+
def __getitem__(self, idx):
|
| 37 |
+
while True:
|
| 38 |
+
index = idx
|
| 39 |
+
try:
|
| 40 |
+
return self.get_single_item(index)
|
| 41 |
+
except Exception as e:
|
| 42 |
+
print(f"Error processing sample at index {idx}: {e}")
|
| 43 |
+
import traceback
|
| 44 |
+
traceback.print_exc()
|
| 45 |
+
index = random.randint(0, len(self.samples) - 1)
|
| 46 |
+
if index == idx:
|
| 47 |
+
continue
|
| 48 |
+
idx = index
|
| 49 |
+
|
| 50 |
+
def get_single_item(self, idx):
|
| 51 |
+
sample = self.samples[idx]
|
| 52 |
+
# Load image paths
|
| 53 |
+
image_1 = sample["path1"]
|
| 54 |
+
image_2 = sample["path2"]
|
| 55 |
+
assert os.path.exists(image_1) and os.path.exists(image_2), f'{image_1} or {image_2}'
|
| 56 |
+
text_1 = sample["prompt"]
|
| 57 |
+
text_2 = sample["prompt"]
|
| 58 |
+
|
| 59 |
+
# Process Label
|
| 60 |
+
if self.soft_label:
|
| 61 |
+
choice_dist = sorted(sample["choice_dist"], reverse=True)
|
| 62 |
+
assert (
|
| 63 |
+
torch.sum(torch.tensor(choice_dist)) > 0
|
| 64 |
+
), "Choice distribution cannot be zero."
|
| 65 |
+
label = torch.tensor(choice_dist[0]) / torch.sum(torch.tensor(choice_dist))
|
| 66 |
+
else:
|
| 67 |
+
label = torch.tensor(1).float()
|
| 68 |
+
# breakpoint()
|
| 69 |
+
return {
|
| 70 |
+
"image_1": image_1,
|
| 71 |
+
"image_2": image_2,
|
| 72 |
+
"text_1": text_1,
|
| 73 |
+
"text_2": text_2,
|
| 74 |
+
"label": label,
|
| 75 |
+
"confidence": sample.get("confidence", 1.0),
|
| 76 |
+
"choice_dist": torch.tensor(sample.get("choice_dist", [1.0, 0.0])),
|
| 77 |
+
}
|
hpsv3/dataset/utils.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
## This file is modified from https://github.com/kq-chen/qwen-vl-utils/blob/main/src/qwen_vl_utils/vision_process.py
|
| 5 |
+
import base64
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
import warnings
|
| 12 |
+
from functools import lru_cache
|
| 13 |
+
from io import BytesIO
|
| 14 |
+
|
| 15 |
+
import requests
|
| 16 |
+
import torch
|
| 17 |
+
import torchvision
|
| 18 |
+
from packaging import version
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from torchvision import io, transforms
|
| 21 |
+
from torchvision.transforms import InterpolationMode
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
IMAGE_FACTOR = 28
|
| 27 |
+
MIN_PIXELS = 4 * 28 * 28
|
| 28 |
+
MAX_PIXELS = 16384 * 28 * 28
|
| 29 |
+
MAX_RATIO = 200
|
| 30 |
+
|
| 31 |
+
VIDEO_MIN_PIXELS = 128 * 28 * 28
|
| 32 |
+
VIDEO_MAX_PIXELS = 768 * 28 * 28
|
| 33 |
+
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
|
| 34 |
+
FRAME_FACTOR = 2
|
| 35 |
+
FPS = 2.0
|
| 36 |
+
FPS_MIN_FRAMES = 4
|
| 37 |
+
FPS_MAX_FRAMES = 768
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def round_by_factor(number: int, factor: int) -> int:
|
| 41 |
+
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
| 42 |
+
return round(number / factor) * factor
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def ceil_by_factor(number: int, factor: int) -> int:
|
| 46 |
+
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
| 47 |
+
return math.ceil(number / factor) * factor
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def floor_by_factor(number: int, factor: int) -> int:
|
| 51 |
+
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
| 52 |
+
return math.floor(number / factor) * factor
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def smart_resize(
|
| 56 |
+
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
| 57 |
+
) -> tuple[int, int]:
|
| 58 |
+
"""
|
| 59 |
+
Rescales the image so that the following conditions are met:
|
| 60 |
+
|
| 61 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
| 62 |
+
|
| 63 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
| 64 |
+
|
| 65 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
| 66 |
+
"""
|
| 67 |
+
if max(height, width) / min(height, width) > MAX_RATIO:
|
| 68 |
+
raise ValueError(
|
| 69 |
+
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
| 70 |
+
)
|
| 71 |
+
h_bar = max(factor, round_by_factor(height, factor))
|
| 72 |
+
w_bar = max(factor, round_by_factor(width, factor))
|
| 73 |
+
if h_bar * w_bar > max_pixels:
|
| 74 |
+
beta = math.sqrt((height * width) / max_pixels)
|
| 75 |
+
h_bar = floor_by_factor(height / beta, factor)
|
| 76 |
+
w_bar = floor_by_factor(width / beta, factor)
|
| 77 |
+
elif h_bar * w_bar < min_pixels:
|
| 78 |
+
beta = math.sqrt(min_pixels / (height * width))
|
| 79 |
+
h_bar = ceil_by_factor(height * beta, factor)
|
| 80 |
+
w_bar = ceil_by_factor(width * beta, factor)
|
| 81 |
+
return h_bar, w_bar
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
|
| 85 |
+
if "image" in ele:
|
| 86 |
+
image = ele["image"]
|
| 87 |
+
else:
|
| 88 |
+
image = ele["image_url"]
|
| 89 |
+
image_obj = None
|
| 90 |
+
if isinstance(image, Image.Image):
|
| 91 |
+
image_obj = image
|
| 92 |
+
elif isinstance(image, torch.Tensor):
|
| 93 |
+
image_obj = image
|
| 94 |
+
elif image.startswith("http://") or image.startswith("https://"):
|
| 95 |
+
image_obj = Image.open(requests.get(image, stream=True).raw)
|
| 96 |
+
elif image.startswith("file://"):
|
| 97 |
+
image_obj = Image.open(image[7:])
|
| 98 |
+
elif image.startswith("data:image"):
|
| 99 |
+
if "base64," in image:
|
| 100 |
+
_, base64_data = image.split("base64,", 1)
|
| 101 |
+
data = base64.b64decode(base64_data)
|
| 102 |
+
image_obj = Image.open(BytesIO(data))
|
| 103 |
+
else:
|
| 104 |
+
image_obj = Image.open(image)
|
| 105 |
+
if image_obj is None:
|
| 106 |
+
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
| 107 |
+
if isinstance(image_obj, Image.Image):
|
| 108 |
+
image = image_obj.convert("RGB")
|
| 109 |
+
## resize
|
| 110 |
+
if "resized_height" in ele and "resized_width" in ele:
|
| 111 |
+
resized_height, resized_width = smart_resize(
|
| 112 |
+
ele["resized_height"],
|
| 113 |
+
ele["resized_width"],
|
| 114 |
+
factor=size_factor,
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
if isinstance(image, torch.Tensor):
|
| 118 |
+
shape = image.shape
|
| 119 |
+
if len(shape) == 4:
|
| 120 |
+
if shape[1] in [1, 3]: # Likely [B, C, H, W]
|
| 121 |
+
height, width = shape[2], shape[3]
|
| 122 |
+
image_mode = 'NCHW'
|
| 123 |
+
elif shape[3] in [1, 3]: # Likely [B, H, W, C]
|
| 124 |
+
height, width = shape[1], shape[2]
|
| 125 |
+
image_mode = 'NHWC'
|
| 126 |
+
|
| 127 |
+
elif len(shape) == 3:
|
| 128 |
+
if shape[0] in [1, 3]: # Likely [C, H, W]
|
| 129 |
+
height, width = shape[1], shape[2]
|
| 130 |
+
image_mode = 'CHW'
|
| 131 |
+
elif shape[2] in [1, 3]: # Likely [H, W, C]
|
| 132 |
+
height, width = shape[0], shape[1]
|
| 133 |
+
image_mode = 'HWC'
|
| 134 |
+
else:
|
| 135 |
+
raise ValueError(f"Cannot determine tensor image format from shape {shape}")
|
| 136 |
+
else:
|
| 137 |
+
raise ValueError(f"Unsupported tensor image shape: {shape}")
|
| 138 |
+
else:
|
| 139 |
+
width, height = image.size
|
| 140 |
+
min_pixels = ele.get("min_pixels", MIN_PIXELS)
|
| 141 |
+
max_pixels = ele.get("max_pixels", MAX_PIXELS)
|
| 142 |
+
resized_height, resized_width = smart_resize(
|
| 143 |
+
height,
|
| 144 |
+
width,
|
| 145 |
+
factor=size_factor,
|
| 146 |
+
min_pixels=min_pixels,
|
| 147 |
+
max_pixels=max_pixels,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
if isinstance(image, torch.Tensor):
|
| 151 |
+
if image_mode == 'NCHW':
|
| 152 |
+
image = transforms.functional.resize(
|
| 153 |
+
image, [resized_height, resized_width], interpolation=InterpolationMode.BICUBIC, antialias=True
|
| 154 |
+
)
|
| 155 |
+
elif image_mode == 'NHWC':
|
| 156 |
+
image = transforms.functional.resize(
|
| 157 |
+
image.permute(0, 3, 1, 2), [resized_height, resized_width], interpolation=InterpolationMode.BICUBIC, antialias=True
|
| 158 |
+
)
|
| 159 |
+
elif image_mode == 'CHW':
|
| 160 |
+
image = image.unsqueeze(0) # Add batch dimension
|
| 161 |
+
image = transforms.functional.resize(
|
| 162 |
+
image, [resized_height, resized_width], interpolation=InterpolationMode.BICUBIC, antialias=True
|
| 163 |
+
)
|
| 164 |
+
elif image_mode == 'HWC':
|
| 165 |
+
image = image.permute(2, 0, 1).unsqueeze(0) # Add batch dimension and change to CHW
|
| 166 |
+
image = transforms.functional.resize(
|
| 167 |
+
image, [resized_height, resized_width], interpolation=InterpolationMode.BICUBIC, antialias=True
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
else:
|
| 171 |
+
# If the image is a PIL Image, we resize it using PIL.
|
| 172 |
+
if image.mode != "RGB":
|
| 173 |
+
image = image.convert("RGB")
|
| 174 |
+
image = image.resize((resized_width, resized_height), Image.BICUBIC)
|
| 175 |
+
|
| 176 |
+
return image
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def smart_nframes(
|
| 180 |
+
ele: dict,
|
| 181 |
+
total_frames: int,
|
| 182 |
+
video_fps: int | float,
|
| 183 |
+
) -> int:
|
| 184 |
+
"""calculate the number of frames for video used for model inputs.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
ele (dict): a dict contains the configuration of video.
|
| 188 |
+
support either `fps` or `nframes`:
|
| 189 |
+
- nframes: the number of frames to extract for model inputs.
|
| 190 |
+
- fps: the fps to extract frames for model inputs.
|
| 191 |
+
- min_frames: the minimum number of frames of the video, only used when fps is provided.
|
| 192 |
+
- max_frames: the maximum number of frames of the video, only used when fps is provided.
|
| 193 |
+
total_frames (int): the original total number of frames of the video.
|
| 194 |
+
video_fps (int | float): the original fps of the video.
|
| 195 |
+
|
| 196 |
+
Raises:
|
| 197 |
+
ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
int: the number of frames for video used for model inputs.
|
| 201 |
+
"""
|
| 202 |
+
assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
|
| 203 |
+
if "nframes" in ele:
|
| 204 |
+
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
|
| 205 |
+
else:
|
| 206 |
+
fps = ele.get("fps", FPS)
|
| 207 |
+
min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
|
| 208 |
+
max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
|
| 209 |
+
nframes = total_frames / video_fps * fps
|
| 210 |
+
nframes = min(max(nframes, min_frames), max_frames)
|
| 211 |
+
nframes = round_by_factor(nframes, FRAME_FACTOR)
|
| 212 |
+
if nframes > total_frames:
|
| 213 |
+
nframes = total_frames
|
| 214 |
+
if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
|
| 215 |
+
raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
|
| 216 |
+
return nframes
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _read_video_torchvision(
|
| 220 |
+
ele: dict,
|
| 221 |
+
) -> torch.Tensor:
|
| 222 |
+
"""read video using torchvision.io.read_video
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
ele (dict): a dict contains the configuration of video.
|
| 226 |
+
support keys:
|
| 227 |
+
- video: the path of video. support "file://", "http://", "https://" and local path.
|
| 228 |
+
- video_start: the start time of video.
|
| 229 |
+
- video_end: the end time of video.
|
| 230 |
+
Returns:
|
| 231 |
+
torch.Tensor: the video tensor with shape (T, C, H, W).
|
| 232 |
+
"""
|
| 233 |
+
video_path = ele["video"]
|
| 234 |
+
if version.parse(torchvision.__version__) < version.parse("0.19.0"):
|
| 235 |
+
if "http://" in video_path or "https://" in video_path:
|
| 236 |
+
warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
|
| 237 |
+
if "file://" in video_path:
|
| 238 |
+
video_path = video_path[7:]
|
| 239 |
+
st = time.time()
|
| 240 |
+
video, audio, info = io.read_video(
|
| 241 |
+
video_path,
|
| 242 |
+
start_pts=ele.get("video_start", 0.0),
|
| 243 |
+
end_pts=ele.get("video_end", None),
|
| 244 |
+
pts_unit="sec",
|
| 245 |
+
output_format="TCHW",
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
total_frames, video_fps = video.size(0), info["video_fps"]
|
| 249 |
+
# logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
|
| 250 |
+
if ele['sample_type'] == 'uniform':
|
| 251 |
+
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
|
| 252 |
+
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
|
| 253 |
+
elif ele['sample_type'] == 'multi_pts':
|
| 254 |
+
frames_each_pts = 6
|
| 255 |
+
num_pts = 4
|
| 256 |
+
fps = 8
|
| 257 |
+
nframes = int(total_frames * fps // video_fps)
|
| 258 |
+
frames_idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
|
| 259 |
+
|
| 260 |
+
start_pt = int(frames_each_pts // 2)
|
| 261 |
+
end_pt = int(nframes - frames_each_pts // 2 - 1)
|
| 262 |
+
pts = torch.linspace(start_pt, end_pt, num_pts).round().long().tolist()
|
| 263 |
+
idx = []
|
| 264 |
+
for pt in pts:
|
| 265 |
+
idx.extend(frames_idx[pt - frames_each_pts // 2 : pt + frames_each_pts // 2])
|
| 266 |
+
|
| 267 |
+
video = video[idx]
|
| 268 |
+
return video
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def is_decord_available() -> bool:
|
| 272 |
+
import importlib.util
|
| 273 |
+
|
| 274 |
+
return importlib.util.find_spec("decord") is not None
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _read_video_decord(
|
| 278 |
+
ele: dict,
|
| 279 |
+
) -> torch.Tensor:
|
| 280 |
+
"""read video using decord.VideoReader
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
ele (dict): a dict contains the configuration of video.
|
| 284 |
+
support keys:
|
| 285 |
+
- video: the path of video. support "file://", "http://", "https://" and local path.
|
| 286 |
+
- video_start: the start time of video.
|
| 287 |
+
- video_end: the end time of video.
|
| 288 |
+
Returns:
|
| 289 |
+
torch.Tensor: the video tensor with shape (T, C, H, W).
|
| 290 |
+
"""
|
| 291 |
+
import decord
|
| 292 |
+
video_path = ele["video"]
|
| 293 |
+
st = time.time()
|
| 294 |
+
vr = decord.VideoReader(video_path)
|
| 295 |
+
# TODO: support start_pts and end_pts
|
| 296 |
+
if 'video_start' in ele or 'video_end' in ele:
|
| 297 |
+
raise NotImplementedError("not support start_pts and end_pts in decord for now.")
|
| 298 |
+
total_frames, video_fps = len(vr), vr.get_avg_fps()
|
| 299 |
+
# logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
|
| 300 |
+
if ele['sample_type'] == 'uniform':
|
| 301 |
+
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
|
| 302 |
+
# nframes = max(nframes, 8)
|
| 303 |
+
# import pdb; pdb.set_trace()
|
| 304 |
+
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
|
| 305 |
+
elif ele['sample_type'] == 'multi_pts':
|
| 306 |
+
frames_each_pts = 6
|
| 307 |
+
num_pts = 4
|
| 308 |
+
fps = 8
|
| 309 |
+
nframes = int(total_frames * fps // video_fps)
|
| 310 |
+
frames_idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
|
| 311 |
+
|
| 312 |
+
start_pt = int(frames_each_pts // 2)
|
| 313 |
+
end_pt = int(nframes - frames_each_pts // 2 - 1)
|
| 314 |
+
pts = torch.linspace(start_pt, end_pt, num_pts).round().long().tolist()
|
| 315 |
+
idx = []
|
| 316 |
+
for pt in pts:
|
| 317 |
+
idx.extend(frames_idx[pt - frames_each_pts // 2 : pt + frames_each_pts // 2])
|
| 318 |
+
video = vr.get_batch(idx).asnumpy()
|
| 319 |
+
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
|
| 320 |
+
return video
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
VIDEO_READER_BACKENDS = {
|
| 324 |
+
"decord": _read_video_decord,
|
| 325 |
+
"torchvision": _read_video_torchvision,
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
@lru_cache(maxsize=1)
|
| 332 |
+
def get_video_reader_backend() -> str:
|
| 333 |
+
if FORCE_QWENVL_VIDEO_READER is not None:
|
| 334 |
+
video_reader_backend = FORCE_QWENVL_VIDEO_READER
|
| 335 |
+
elif is_decord_available():
|
| 336 |
+
video_reader_backend = "decord"
|
| 337 |
+
else:
|
| 338 |
+
video_reader_backend = "torchvision"
|
| 339 |
+
print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
|
| 340 |
+
return video_reader_backend
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
|
| 344 |
+
if isinstance(ele["video"], str):
|
| 345 |
+
video_reader_backend = get_video_reader_backend()
|
| 346 |
+
video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
|
| 347 |
+
# import pdb; pdb.set_trace()
|
| 348 |
+
nframes, _, height, width = video.shape
|
| 349 |
+
|
| 350 |
+
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
|
| 351 |
+
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
|
| 352 |
+
max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
|
| 353 |
+
max_pixels = ele.get("max_pixels", max_pixels)
|
| 354 |
+
if "resized_height" in ele and "resized_width" in ele:
|
| 355 |
+
resized_height, resized_width = smart_resize(
|
| 356 |
+
ele["resized_height"],
|
| 357 |
+
ele["resized_width"],
|
| 358 |
+
factor=image_factor,
|
| 359 |
+
)
|
| 360 |
+
else:
|
| 361 |
+
resized_height, resized_width = smart_resize(
|
| 362 |
+
height,
|
| 363 |
+
width,
|
| 364 |
+
factor=image_factor,
|
| 365 |
+
min_pixels=min_pixels,
|
| 366 |
+
max_pixels=max_pixels,
|
| 367 |
+
)
|
| 368 |
+
video = transforms.functional.resize(
|
| 369 |
+
video,
|
| 370 |
+
[resized_height, resized_width],
|
| 371 |
+
interpolation=InterpolationMode.BICUBIC,
|
| 372 |
+
antialias=True,
|
| 373 |
+
).float()
|
| 374 |
+
return video
|
| 375 |
+
else:
|
| 376 |
+
assert isinstance(ele["video"], (list, tuple))
|
| 377 |
+
process_info = ele.copy()
|
| 378 |
+
process_info.pop("type", None)
|
| 379 |
+
process_info.pop("video", None)
|
| 380 |
+
images = [
|
| 381 |
+
fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
|
| 382 |
+
for video_element in ele["video"]
|
| 383 |
+
]
|
| 384 |
+
nframes = ceil_by_factor(len(images), FRAME_FACTOR)
|
| 385 |
+
if len(images) < nframes:
|
| 386 |
+
images.extend([images[-1]] * (nframes - len(images)))
|
| 387 |
+
return images
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
| 391 |
+
vision_infos = []
|
| 392 |
+
if isinstance(conversations[0], dict):
|
| 393 |
+
conversations = [conversations]
|
| 394 |
+
for conversation in conversations:
|
| 395 |
+
for message in conversation:
|
| 396 |
+
if isinstance(message["content"], list):
|
| 397 |
+
for ele in message["content"]:
|
| 398 |
+
if (
|
| 399 |
+
"image" in ele
|
| 400 |
+
or "image_url" in ele
|
| 401 |
+
or "video" in ele
|
| 402 |
+
or ele["type"] in ("image", "image_url", "video")
|
| 403 |
+
):
|
| 404 |
+
vision_infos.append(ele)
|
| 405 |
+
return vision_infos
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def process_vision_info(
|
| 409 |
+
conversations: list[dict] | list[list[dict]],
|
| 410 |
+
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]:
|
| 411 |
+
vision_infos = extract_vision_info(conversations)
|
| 412 |
+
## Read images or videos
|
| 413 |
+
image_inputs = []
|
| 414 |
+
video_inputs = []
|
| 415 |
+
for vision_info in vision_infos:
|
| 416 |
+
if "image" in vision_info or "image_url" in vision_info:
|
| 417 |
+
image_inputs.append(fetch_image(vision_info))
|
| 418 |
+
elif "video" in vision_info:
|
| 419 |
+
video_inputs.append(fetch_video(vision_info))
|
| 420 |
+
else:
|
| 421 |
+
raise ValueError("image, image_url or video should in content.")
|
| 422 |
+
if len(image_inputs) == 0:
|
| 423 |
+
image_inputs = None
|
| 424 |
+
if len(video_inputs) == 0:
|
| 425 |
+
video_inputs = None
|
| 426 |
+
return image_inputs, video_inputs
|
hpsv3/inference.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
from collections.abc import Mapping
|
| 4 |
+
import torch
|
| 5 |
+
import huggingface_hub
|
| 6 |
+
from .dataset.utils import process_vision_info
|
| 7 |
+
from .dataset.data_collator_qwen import prompt_with_special_token, prompt_without_special_token, INSTRUCTION
|
| 8 |
+
from .utils.parser import ModelConfig, PEFTLoraConfig, TrainingConfig, DataConfig, parse_args_with_yaml
|
| 9 |
+
from .train import create_model_and_processor
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
_MODEL_CONFIG_PATH = Path(__file__).parent / f"config/"
|
| 13 |
+
|
| 14 |
+
class HPSv3RewardInferencer():
|
| 15 |
+
def __init__(self, config_path=None, checkpoint_path=None, device='cuda', differentiable=False):
|
| 16 |
+
if config_path is None:
|
| 17 |
+
config_path = os.path.join(_MODEL_CONFIG_PATH, 'HPSv3_7B.yaml')
|
| 18 |
+
|
| 19 |
+
if checkpoint_path is None:
|
| 20 |
+
checkpoint_path = huggingface_hub.hf_hub_download("MizzenAI/HPSv3", 'HPSv3.safetensors', repo_type='model')
|
| 21 |
+
|
| 22 |
+
(data_config, training_args, model_config, peft_lora_config), config_path = (
|
| 23 |
+
parse_args_with_yaml(
|
| 24 |
+
(DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig), config_path, is_train=False
|
| 25 |
+
)
|
| 26 |
+
)
|
| 27 |
+
training_args.output_dir = os.path.join(
|
| 28 |
+
training_args.output_dir, config_path.split("/")[-1].split(".")[0]
|
| 29 |
+
)
|
| 30 |
+
model, processor, peft_config = create_model_and_processor(
|
| 31 |
+
model_config=model_config,
|
| 32 |
+
peft_lora_config=peft_lora_config,
|
| 33 |
+
training_args=training_args,
|
| 34 |
+
differentiable=differentiable,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
self.device = device
|
| 38 |
+
self.use_special_tokens = model_config.use_special_tokens
|
| 39 |
+
|
| 40 |
+
if checkpoint_path.endswith('.safetensors'):
|
| 41 |
+
import safetensors.torch
|
| 42 |
+
state_dict = safetensors.torch.load_file(checkpoint_path, device="cpu")
|
| 43 |
+
else:
|
| 44 |
+
state_dict = torch.load(checkpoint_path , map_location="cpu")
|
| 45 |
+
|
| 46 |
+
if "model" in state_dict:
|
| 47 |
+
state_dict = state_dict["model"]
|
| 48 |
+
model.load_state_dict(state_dict, strict=True)
|
| 49 |
+
model.eval()
|
| 50 |
+
|
| 51 |
+
self.model = model
|
| 52 |
+
self.processor = processor
|
| 53 |
+
|
| 54 |
+
self.model.to(self.device)
|
| 55 |
+
self.data_config = data_config
|
| 56 |
+
|
| 57 |
+
def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'):
|
| 58 |
+
"""
|
| 59 |
+
Pad the sequences to the maximum length.
|
| 60 |
+
"""
|
| 61 |
+
assert padding_side in ['right', 'left']
|
| 62 |
+
if sequences.shape[1] >= max_len:
|
| 63 |
+
return sequences, attention_mask
|
| 64 |
+
|
| 65 |
+
pad_len = max_len - sequences.shape[1]
|
| 66 |
+
padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0)
|
| 67 |
+
|
| 68 |
+
sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id)
|
| 69 |
+
attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0)
|
| 70 |
+
|
| 71 |
+
return sequences_padded, attention_mask_padded
|
| 72 |
+
|
| 73 |
+
def _prepare_input(self, data):
|
| 74 |
+
"""
|
| 75 |
+
Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
|
| 76 |
+
handling potential state.
|
| 77 |
+
"""
|
| 78 |
+
if isinstance(data, Mapping):
|
| 79 |
+
return type(data)({k: self._prepare_input(v) for k, v in data.items()})
|
| 80 |
+
elif isinstance(data, (tuple, list)):
|
| 81 |
+
return type(data)(self._prepare_input(v) for v in data)
|
| 82 |
+
elif isinstance(data, torch.Tensor):
|
| 83 |
+
kwargs = {"device": self.device}
|
| 84 |
+
return data.to(**kwargs)
|
| 85 |
+
return data
|
| 86 |
+
|
| 87 |
+
def _prepare_inputs(self, inputs):
|
| 88 |
+
"""
|
| 89 |
+
Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
|
| 90 |
+
handling potential state.
|
| 91 |
+
"""
|
| 92 |
+
inputs = self._prepare_input(inputs)
|
| 93 |
+
if len(inputs) == 0:
|
| 94 |
+
raise ValueError
|
| 95 |
+
return inputs
|
| 96 |
+
|
| 97 |
+
def prepare_batch(self, image_paths, prompts):
|
| 98 |
+
max_pixels = 256 * 28 * 28
|
| 99 |
+
min_pixels = 256 * 28 * 28
|
| 100 |
+
message_list = []
|
| 101 |
+
for text, image in zip(prompts, image_paths):
|
| 102 |
+
out_message = [
|
| 103 |
+
{
|
| 104 |
+
"role": "user",
|
| 105 |
+
"content": [
|
| 106 |
+
{
|
| 107 |
+
"type": "image",
|
| 108 |
+
"image": image,
|
| 109 |
+
"min_pixels": max_pixels,
|
| 110 |
+
"max_pixels": max_pixels,
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"type": "text",
|
| 114 |
+
"text": (
|
| 115 |
+
INSTRUCTION.format(text_prompt=text)
|
| 116 |
+
+ prompt_with_special_token
|
| 117 |
+
if self.use_special_tokens
|
| 118 |
+
else prompt_without_special_token
|
| 119 |
+
),
|
| 120 |
+
},
|
| 121 |
+
],
|
| 122 |
+
}
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
message_list.append(out_message)
|
| 126 |
+
|
| 127 |
+
image_inputs, _ = process_vision_info(message_list)
|
| 128 |
+
|
| 129 |
+
batch = self.processor(
|
| 130 |
+
text=self.processor.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True),
|
| 131 |
+
images=image_inputs,
|
| 132 |
+
padding=True,
|
| 133 |
+
return_tensors="pt",
|
| 134 |
+
videos_kwargs={"do_rescale": True},
|
| 135 |
+
)
|
| 136 |
+
batch = self._prepare_inputs(batch)
|
| 137 |
+
return batch
|
| 138 |
+
|
| 139 |
+
def reward(self, image_paths, prompts):
|
| 140 |
+
|
| 141 |
+
batch = self.prepare_batch(image_paths, prompts)
|
| 142 |
+
rewards = self.model(
|
| 143 |
+
return_dict=True,
|
| 144 |
+
**batch
|
| 145 |
+
)["logits"]
|
| 146 |
+
|
| 147 |
+
return rewards
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
config_path = 'config/inference/HPSv3_7B.yaml'
|
| 152 |
+
checkpoint_path = 'checkpoints/HPSv3_7B.pth'
|
| 153 |
+
device = 'cuda'
|
| 154 |
+
dtype = torch.bfloat16
|
| 155 |
+
inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, device=device)
|
| 156 |
+
|
| 157 |
+
image_paths = [
|
| 158 |
+
"assets/example1.png",
|
| 159 |
+
"assets/example2.png"
|
| 160 |
+
]
|
| 161 |
+
prompts = [
|
| 162 |
+
"cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker",
|
| 163 |
+
"cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker"
|
| 164 |
+
]
|
| 165 |
+
rewards = inferencer.reward(image_paths, prompts)
|
| 166 |
+
print(rewards[0][0].item()) # miu and sigma. we select miu as the final output
|
| 167 |
+
print(rewards[1][0].item())
|
hpsv3/model/differentiable_image_processor.py
ADDED
|
@@ -0,0 +1,629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""Image processor class for Qwen2-VL.
|
| 21 |
+
|
| 22 |
+
This module provides both differentiable and non-differentiable image processing methods:
|
| 23 |
+
|
| 24 |
+
1. For DIFFERENTIABLE processing (torch.autograd compatible):
|
| 25 |
+
- Pass torch.Tensor to _preprocess() method
|
| 26 |
+
- Use preprocess_tensor() method directly
|
| 27 |
+
- All operations use PyTorch functions (F.interpolate, tensor operations, etc.)
|
| 28 |
+
|
| 29 |
+
2. For NON-DIFFERENTIABLE processing (original functionality):
|
| 30 |
+
- Pass PIL images or numpy arrays to preprocess() method
|
| 31 |
+
- Uses PIL/transformers image processing functions and numpy operations
|
| 32 |
+
|
| 33 |
+
The differentiable path supports:
|
| 34 |
+
- Bilinear interpolation for resizing (instead of PIL resampling)
|
| 35 |
+
- Tensor-based rescaling and normalization
|
| 36 |
+
- Differentiable patch extraction and reshaping
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
import math
|
| 40 |
+
from typing import Dict, List, Optional, Union
|
| 41 |
+
|
| 42 |
+
import numpy as np
|
| 43 |
+
import torch
|
| 44 |
+
import torch.nn.functional as F
|
| 45 |
+
|
| 46 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| 47 |
+
from transformers.image_transforms import (
|
| 48 |
+
convert_to_rgb,
|
| 49 |
+
resize,
|
| 50 |
+
to_channel_dimension_format,
|
| 51 |
+
)
|
| 52 |
+
from transformers.image_utils import (
|
| 53 |
+
OPENAI_CLIP_MEAN,
|
| 54 |
+
OPENAI_CLIP_STD,
|
| 55 |
+
ChannelDimension,
|
| 56 |
+
ImageInput,
|
| 57 |
+
PILImageResampling,
|
| 58 |
+
VideoInput,
|
| 59 |
+
get_image_size,
|
| 60 |
+
infer_channel_dimension_format,
|
| 61 |
+
is_scaled_image,
|
| 62 |
+
is_valid_image,
|
| 63 |
+
make_list_of_images,
|
| 64 |
+
to_numpy_array,
|
| 65 |
+
valid_images,
|
| 66 |
+
validate_preprocess_arguments,
|
| 67 |
+
)
|
| 68 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
logger = logging.get_logger(__name__)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if is_vision_available():
|
| 75 |
+
from PIL import Image
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def make_batched_images(images) -> List[List[ImageInput]]:
|
| 79 |
+
"""
|
| 80 |
+
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
| 84 |
+
The input image.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
list: A list of images.
|
| 88 |
+
"""
|
| 89 |
+
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
| 90 |
+
return [img for img_list in images for img in img_list]
|
| 91 |
+
|
| 92 |
+
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
| 93 |
+
return images
|
| 94 |
+
|
| 95 |
+
elif is_valid_image(images):
|
| 96 |
+
return [images]
|
| 97 |
+
|
| 98 |
+
raise ValueError(f"Could not make batched images from {images}")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
|
| 102 |
+
def make_batched_videos(videos) -> List[VideoInput]:
|
| 103 |
+
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
| 104 |
+
return videos
|
| 105 |
+
|
| 106 |
+
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
| 107 |
+
if isinstance(videos[0], Image.Image):
|
| 108 |
+
return [videos]
|
| 109 |
+
elif len(videos[0].shape) == 4:
|
| 110 |
+
return [list(video) for video in videos]
|
| 111 |
+
|
| 112 |
+
elif is_valid_image(videos) and len(videos.shape) == 4:
|
| 113 |
+
return [list(videos)]
|
| 114 |
+
|
| 115 |
+
raise ValueError(f"Could not make batched video from {videos}")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def smart_resize(
|
| 119 |
+
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
|
| 120 |
+
):
|
| 121 |
+
"""Rescales the image so that the following conditions are met:
|
| 122 |
+
|
| 123 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
| 124 |
+
|
| 125 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
| 126 |
+
|
| 127 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
| 128 |
+
|
| 129 |
+
"""
|
| 130 |
+
if height < factor or width < factor:
|
| 131 |
+
raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
|
| 132 |
+
elif max(height, width) / min(height, width) > 200:
|
| 133 |
+
raise ValueError(
|
| 134 |
+
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
| 135 |
+
)
|
| 136 |
+
h_bar = round(height / factor) * factor
|
| 137 |
+
w_bar = round(width / factor) * factor
|
| 138 |
+
if h_bar * w_bar > max_pixels:
|
| 139 |
+
beta = math.sqrt((height * width) / max_pixels)
|
| 140 |
+
h_bar = math.floor(height / beta / factor) * factor
|
| 141 |
+
w_bar = math.floor(width / beta / factor) * factor
|
| 142 |
+
elif h_bar * w_bar < min_pixels:
|
| 143 |
+
beta = math.sqrt(min_pixels / (height * width))
|
| 144 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
| 145 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
| 146 |
+
return h_bar, w_bar
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class Qwen2VLImageProcessor(BaseImageProcessor):
|
| 150 |
+
r"""
|
| 151 |
+
Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 155 |
+
Whether to resize the image's (height, width) dimensions.
|
| 156 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 157 |
+
Resampling filter to use when resizing the image.
|
| 158 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 159 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
| 160 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 161 |
+
Scale factor to use if rescaling the image.
|
| 162 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 163 |
+
Whether to normalize the image.
|
| 164 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
| 165 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 166 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
| 167 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 168 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 169 |
+
Whether to convert the image to RGB.
|
| 170 |
+
min_pixels (`int`, *optional*, defaults to `56 * 56`):
|
| 171 |
+
The min pixels of the image to resize the image.
|
| 172 |
+
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
|
| 173 |
+
The max pixels of the image to resize the image.
|
| 174 |
+
patch_size (`int`, *optional*, defaults to 14):
|
| 175 |
+
The spacial patch size of the vision encoder.
|
| 176 |
+
temporal_patch_size (`int`, *optional*, defaults to 2):
|
| 177 |
+
The temporal patch size of the vision encoder.
|
| 178 |
+
merge_size (`int`, *optional*, defaults to 2):
|
| 179 |
+
The merge size of the vision encoder to llm encoder.
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
|
| 183 |
+
|
| 184 |
+
def __init__(
|
| 185 |
+
self,
|
| 186 |
+
do_resize: bool = True,
|
| 187 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 188 |
+
do_rescale: bool = True,
|
| 189 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 190 |
+
do_normalize: bool = True,
|
| 191 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 192 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 193 |
+
do_convert_rgb: bool = True,
|
| 194 |
+
min_pixels: int = 56 * 56,
|
| 195 |
+
max_pixels: int = 28 * 28 * 1280,
|
| 196 |
+
patch_size: int = 14,
|
| 197 |
+
temporal_patch_size: int = 2,
|
| 198 |
+
merge_size: int = 2,
|
| 199 |
+
**kwargs,
|
| 200 |
+
) -> None:
|
| 201 |
+
super().__init__(**kwargs)
|
| 202 |
+
self.do_resize = do_resize
|
| 203 |
+
self.resample = resample
|
| 204 |
+
self.do_rescale = do_rescale
|
| 205 |
+
self.rescale_factor = rescale_factor
|
| 206 |
+
self.do_normalize = do_normalize
|
| 207 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
| 208 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
| 209 |
+
self.min_pixels = min_pixels
|
| 210 |
+
self.max_pixels = max_pixels
|
| 211 |
+
self.patch_size = patch_size
|
| 212 |
+
self.temporal_patch_size = temporal_patch_size
|
| 213 |
+
self.merge_size = merge_size
|
| 214 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
| 215 |
+
self.do_convert_rgb = do_convert_rgb
|
| 216 |
+
|
| 217 |
+
def _preprocess_differentiable(
|
| 218 |
+
self,
|
| 219 |
+
images: torch.Tensor,
|
| 220 |
+
do_resize: bool = None,
|
| 221 |
+
do_rescale: bool = None,
|
| 222 |
+
rescale_factor: float = None,
|
| 223 |
+
do_normalize: bool = None,
|
| 224 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 225 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 226 |
+
):
|
| 227 |
+
"""
|
| 228 |
+
Differentiable version of image preprocessing using torch operations.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
images: torch.Tensor of shape (B, C, H, W) or (C, H, W)
|
| 232 |
+
Returns:
|
| 233 |
+
flatten_patches: torch.Tensor - flattened patches
|
| 234 |
+
grid_thw: tuple - (grid_t, grid_h, grid_w)
|
| 235 |
+
"""
|
| 236 |
+
if images.dim() == 3:
|
| 237 |
+
images = images.unsqueeze(0) # Add batch dimension
|
| 238 |
+
|
| 239 |
+
batch_size, channels, height, width = images.shape
|
| 240 |
+
|
| 241 |
+
processed_images = []
|
| 242 |
+
resized_height, resized_width = height, width
|
| 243 |
+
|
| 244 |
+
for i in range(batch_size):
|
| 245 |
+
image = images[i] # (C, H, W)
|
| 246 |
+
|
| 247 |
+
if do_resize:
|
| 248 |
+
resized_height, resized_width = smart_resize(
|
| 249 |
+
height,
|
| 250 |
+
width,
|
| 251 |
+
factor=self.patch_size * self.merge_size,
|
| 252 |
+
min_pixels=self.min_pixels,
|
| 253 |
+
max_pixels=self.max_pixels,
|
| 254 |
+
)
|
| 255 |
+
# Use differentiable interpolation
|
| 256 |
+
image = F.interpolate(
|
| 257 |
+
image.unsqueeze(0),
|
| 258 |
+
size=(resized_height, resized_width),
|
| 259 |
+
mode='bilinear',
|
| 260 |
+
align_corners=False
|
| 261 |
+
).squeeze(0)
|
| 262 |
+
|
| 263 |
+
if do_rescale:
|
| 264 |
+
image = image * rescale_factor
|
| 265 |
+
|
| 266 |
+
if do_normalize:
|
| 267 |
+
if isinstance(image_mean, (list, tuple)):
|
| 268 |
+
mean = torch.tensor(image_mean, device=image.device, dtype=image.dtype).view(-1, 1, 1)
|
| 269 |
+
std = torch.tensor(image_std, device=image.device, dtype=image.dtype).view(-1, 1, 1)
|
| 270 |
+
else:
|
| 271 |
+
mean = image_mean
|
| 272 |
+
std = image_std
|
| 273 |
+
image = (image - mean) / std
|
| 274 |
+
|
| 275 |
+
processed_images.append(image)
|
| 276 |
+
|
| 277 |
+
# Stack all processed images
|
| 278 |
+
patches = torch.stack(processed_images) # (B, C, H, W)
|
| 279 |
+
|
| 280 |
+
# Handle temporal dimension
|
| 281 |
+
if patches.shape[0] == 1:
|
| 282 |
+
patches = patches.repeat(self.temporal_patch_size, 1, 1, 1)
|
| 283 |
+
|
| 284 |
+
# Reshape for patch extraction
|
| 285 |
+
batch_size, channel, resized_height, resized_width = patches.shape
|
| 286 |
+
grid_t = batch_size // self.temporal_patch_size
|
| 287 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
| 288 |
+
|
| 289 |
+
# Differentiable patch extraction and reshaping
|
| 290 |
+
patches = patches.view(
|
| 291 |
+
grid_t,
|
| 292 |
+
self.temporal_patch_size,
|
| 293 |
+
channel,
|
| 294 |
+
grid_h // self.merge_size,
|
| 295 |
+
self.merge_size,
|
| 296 |
+
self.patch_size,
|
| 297 |
+
grid_w // self.merge_size,
|
| 298 |
+
self.merge_size,
|
| 299 |
+
self.patch_size,
|
| 300 |
+
)
|
| 301 |
+
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
| 302 |
+
flatten_patches = patches.reshape(
|
| 303 |
+
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
return flatten_patches, (grid_t, grid_h, grid_w)
|
| 307 |
+
|
| 308 |
+
def _preprocess(
|
| 309 |
+
self,
|
| 310 |
+
images: Union[ImageInput, VideoInput],
|
| 311 |
+
do_resize: bool = None,
|
| 312 |
+
resample: PILImageResampling = None,
|
| 313 |
+
do_rescale: bool = None,
|
| 314 |
+
rescale_factor: float = None,
|
| 315 |
+
do_normalize: bool = None,
|
| 316 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 317 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 318 |
+
do_convert_rgb: bool = None,
|
| 319 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 320 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 321 |
+
):
|
| 322 |
+
"""
|
| 323 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
| 324 |
+
|
| 325 |
+
Args:
|
| 326 |
+
images (`ImageInput`):
|
| 327 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
| 328 |
+
vision_info (`List[Dict]`, *optional*):
|
| 329 |
+
Optional list of dictionaries containing additional information about vision inputs.
|
| 330 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 331 |
+
Whether to resize the image.
|
| 332 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 333 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
| 334 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 335 |
+
Whether to rescale the image.
|
| 336 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 337 |
+
Scale factor to use if rescaling the image.
|
| 338 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 339 |
+
Whether to normalize the image.
|
| 340 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 341 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 342 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 343 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 344 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 345 |
+
Whether to convert the image to RGB.
|
| 346 |
+
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 347 |
+
The channel dimension format for the output image. Can be one of:
|
| 348 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 349 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 350 |
+
- Unset: Use the channel dimension format of the input image.
|
| 351 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 352 |
+
The channel dimension format for the input image. Can be one of:
|
| 353 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 354 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 355 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 356 |
+
"""
|
| 357 |
+
# Check if input is already a torch tensor (differentiable path)
|
| 358 |
+
if isinstance(images, torch.Tensor):
|
| 359 |
+
return self._preprocess_differentiable(
|
| 360 |
+
images,
|
| 361 |
+
do_resize=do_resize,
|
| 362 |
+
do_rescale=do_rescale,
|
| 363 |
+
rescale_factor=rescale_factor,
|
| 364 |
+
do_normalize=do_normalize,
|
| 365 |
+
image_mean=image_mean,
|
| 366 |
+
image_std=image_std,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# Original non-differentiable path for backward compatibility
|
| 370 |
+
images = make_list_of_images(images)
|
| 371 |
+
|
| 372 |
+
if do_convert_rgb:
|
| 373 |
+
images = [convert_to_rgb(image) for image in images]
|
| 374 |
+
|
| 375 |
+
# All transformations expect numpy arrays.
|
| 376 |
+
images = [to_numpy_array(image) for image in images]
|
| 377 |
+
|
| 378 |
+
if is_scaled_image(images[0]) and do_rescale:
|
| 379 |
+
logger.warning_once(
|
| 380 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 381 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 382 |
+
)
|
| 383 |
+
if input_data_format is None:
|
| 384 |
+
# We assume that all images have the same channel dimension format.
|
| 385 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 386 |
+
|
| 387 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
| 388 |
+
resized_height, resized_width = height, width
|
| 389 |
+
processed_images = []
|
| 390 |
+
for image in images:
|
| 391 |
+
if do_resize:
|
| 392 |
+
resized_height, resized_width = smart_resize(
|
| 393 |
+
height,
|
| 394 |
+
width,
|
| 395 |
+
factor=self.patch_size * self.merge_size,
|
| 396 |
+
min_pixels=self.min_pixels,
|
| 397 |
+
max_pixels=self.max_pixels,
|
| 398 |
+
)
|
| 399 |
+
image = resize(
|
| 400 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
if do_rescale:
|
| 404 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
| 405 |
+
|
| 406 |
+
if do_normalize:
|
| 407 |
+
image = self.normalize(
|
| 408 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 412 |
+
processed_images.append(image)
|
| 413 |
+
|
| 414 |
+
# NOTE: The following operations use numpy and are NOT differentiable
|
| 415 |
+
# For differentiable operations, pass torch.Tensor as input to use _preprocess_differentiable
|
| 416 |
+
patches = np.array(processed_images)
|
| 417 |
+
if data_format == ChannelDimension.LAST:
|
| 418 |
+
patches = patches.transpose(0, 3, 1, 2)
|
| 419 |
+
if patches.shape[0] == 1:
|
| 420 |
+
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
|
| 421 |
+
channel = patches.shape[1]
|
| 422 |
+
grid_t = patches.shape[0] // self.temporal_patch_size
|
| 423 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
| 424 |
+
patches = patches.reshape(
|
| 425 |
+
grid_t,
|
| 426 |
+
self.temporal_patch_size,
|
| 427 |
+
channel,
|
| 428 |
+
grid_h // self.merge_size,
|
| 429 |
+
self.merge_size,
|
| 430 |
+
self.patch_size,
|
| 431 |
+
grid_w // self.merge_size,
|
| 432 |
+
self.merge_size,
|
| 433 |
+
self.patch_size,
|
| 434 |
+
)
|
| 435 |
+
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
| 436 |
+
flatten_patches = patches.reshape(
|
| 437 |
+
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
return flatten_patches, (grid_t, grid_h, grid_w)
|
| 441 |
+
|
| 442 |
+
def preprocess_tensor(
|
| 443 |
+
self,
|
| 444 |
+
images: torch.Tensor,
|
| 445 |
+
do_resize: bool = None,
|
| 446 |
+
do_rescale: bool = None,
|
| 447 |
+
rescale_factor: float = None,
|
| 448 |
+
do_normalize: bool = None,
|
| 449 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 450 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 451 |
+
):
|
| 452 |
+
"""
|
| 453 |
+
Differentiable preprocessing method for torch tensors.
|
| 454 |
+
|
| 455 |
+
Args:
|
| 456 |
+
images: torch.Tensor of shape (B, C, H, W) or (C, H, W)
|
| 457 |
+
|
| 458 |
+
Returns:
|
| 459 |
+
dict containing:
|
| 460 |
+
- pixel_values: torch.Tensor - processed patches
|
| 461 |
+
- image_grid_thw: torch.Tensor - grid dimensions
|
| 462 |
+
"""
|
| 463 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 464 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 465 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 466 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 467 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 468 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 469 |
+
|
| 470 |
+
patches, image_grid_thw = self._preprocess_differentiable(
|
| 471 |
+
images,
|
| 472 |
+
do_resize=do_resize,
|
| 473 |
+
do_rescale=do_rescale,
|
| 474 |
+
rescale_factor=rescale_factor,
|
| 475 |
+
do_normalize=do_normalize,
|
| 476 |
+
image_mean=image_mean,
|
| 477 |
+
image_std=image_std,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
return {
|
| 481 |
+
"pixel_values": patches,
|
| 482 |
+
"image_grid_thw": torch.tensor(image_grid_thw, device=patches.device)
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
def preprocess(
|
| 486 |
+
self,
|
| 487 |
+
images: ImageInput,
|
| 488 |
+
videos: VideoInput = None,
|
| 489 |
+
do_resize: bool = None,
|
| 490 |
+
size: Dict[str, int] = None,
|
| 491 |
+
resample: PILImageResampling = None,
|
| 492 |
+
do_rescale: bool = None,
|
| 493 |
+
rescale_factor: float = None,
|
| 494 |
+
do_normalize: bool = None,
|
| 495 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 496 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 497 |
+
do_convert_rgb: bool = None,
|
| 498 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 499 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 500 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 501 |
+
):
|
| 502 |
+
"""
|
| 503 |
+
Args:
|
| 504 |
+
images (`ImageInput`):
|
| 505 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 506 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 507 |
+
videos (`VideoInput`):
|
| 508 |
+
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
|
| 509 |
+
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
|
| 510 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 511 |
+
Whether to resize the image.
|
| 512 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 513 |
+
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
| 514 |
+
the longest edge resized to keep the input aspect ratio.
|
| 515 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 516 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 517 |
+
has an effect if `do_resize` is set to `True`.
|
| 518 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 519 |
+
Whether to rescale the image.
|
| 520 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 521 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 522 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 523 |
+
Whether to normalize the image.
|
| 524 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 525 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 526 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 527 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
| 528 |
+
`True`.
|
| 529 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 530 |
+
Whether to convert the image to RGB.
|
| 531 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 532 |
+
The type of tensors to return. Can be one of:
|
| 533 |
+
- Unset: Return a list of `np.ndarray`.
|
| 534 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 535 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 536 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 537 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 538 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 539 |
+
The channel dimension format for the output image. Can be one of:
|
| 540 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 541 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 542 |
+
- Unset: Use the channel dimension format of the input image.
|
| 543 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 544 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 545 |
+
from the input image. Can be one of:
|
| 546 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 547 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 548 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 549 |
+
|
| 550 |
+
"""
|
| 551 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 552 |
+
size = size if size is not None else self.size
|
| 553 |
+
resample = resample if resample is not None else self.resample
|
| 554 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 555 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 556 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 557 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 558 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 559 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 560 |
+
|
| 561 |
+
if images is not None:
|
| 562 |
+
images = make_batched_images(images)
|
| 563 |
+
if videos is not None:
|
| 564 |
+
videos = make_batched_videos(videos)
|
| 565 |
+
|
| 566 |
+
if images is not None and not valid_images(images):
|
| 567 |
+
raise ValueError(
|
| 568 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 569 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
validate_preprocess_arguments(
|
| 573 |
+
rescale_factor=rescale_factor,
|
| 574 |
+
do_normalize=do_normalize,
|
| 575 |
+
image_mean=image_mean,
|
| 576 |
+
image_std=image_std,
|
| 577 |
+
do_resize=do_resize,
|
| 578 |
+
size=size,
|
| 579 |
+
resample=resample,
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
if images is not None:
|
| 583 |
+
pixel_values, vision_grid_thws = [], []
|
| 584 |
+
for image in images:
|
| 585 |
+
patches, image_grid_thw = self._preprocess(
|
| 586 |
+
image,
|
| 587 |
+
do_resize=do_resize,
|
| 588 |
+
resample=resample,
|
| 589 |
+
do_rescale=do_rescale,
|
| 590 |
+
rescale_factor=rescale_factor,
|
| 591 |
+
do_normalize=do_normalize,
|
| 592 |
+
image_mean=image_mean,
|
| 593 |
+
image_std=image_std,
|
| 594 |
+
data_format=data_format,
|
| 595 |
+
do_convert_rgb=do_convert_rgb,
|
| 596 |
+
input_data_format=input_data_format,
|
| 597 |
+
)
|
| 598 |
+
pixel_values.extend(patches)
|
| 599 |
+
vision_grid_thws.append(image_grid_thw)
|
| 600 |
+
if not isinstance(pixel_values[0], torch.Tensor):
|
| 601 |
+
pixel_values = np.array(pixel_values)
|
| 602 |
+
else:
|
| 603 |
+
pixel_values = torch.stack(pixel_values)
|
| 604 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
| 605 |
+
data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
|
| 606 |
+
|
| 607 |
+
if videos is not None:
|
| 608 |
+
pixel_values, vision_grid_thws = [], []
|
| 609 |
+
for images in videos:
|
| 610 |
+
patches, video_grid_thw = self._preprocess(
|
| 611 |
+
images,
|
| 612 |
+
do_resize=do_resize,
|
| 613 |
+
resample=resample,
|
| 614 |
+
do_rescale=do_rescale,
|
| 615 |
+
rescale_factor=rescale_factor,
|
| 616 |
+
do_normalize=do_normalize,
|
| 617 |
+
image_mean=image_mean,
|
| 618 |
+
image_std=image_std,
|
| 619 |
+
data_format=data_format,
|
| 620 |
+
do_convert_rgb=do_convert_rgb,
|
| 621 |
+
input_data_format=input_data_format,
|
| 622 |
+
)
|
| 623 |
+
pixel_values.extend(patches)
|
| 624 |
+
vision_grid_thws.append(video_grid_thw)
|
| 625 |
+
pixel_values = np.array(pixel_values)
|
| 626 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
| 627 |
+
data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
|
| 628 |
+
|
| 629 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
hpsv3/model/qwen2vl_trainer.py
ADDED
|
@@ -0,0 +1,971 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pdb
|
| 3 |
+
import warnings
|
| 4 |
+
import time
|
| 5 |
+
import math
|
| 6 |
+
import json
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import matplotlib.patches as patches
|
| 10 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
|
| 13 |
+
from typing import List, Optional, Dict, Union, Any
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import safetensors
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import datasets
|
| 20 |
+
from torch.utils.data import Dataset, DataLoader
|
| 21 |
+
from peft import PeftModel
|
| 22 |
+
from transformers import Qwen2VLForConditionalGeneration
|
| 23 |
+
from transformers import AutoConfig
|
| 24 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 25 |
+
from transformers.trainer import TrainerCallback
|
| 26 |
+
from transformers.trainer import (
|
| 27 |
+
is_sagemaker_mp_enabled,
|
| 28 |
+
is_peft_available,
|
| 29 |
+
is_datasets_available,
|
| 30 |
+
WEIGHTS_NAME,
|
| 31 |
+
TRAINING_ARGS_NAME,
|
| 32 |
+
SAFE_WEIGHTS_NAME,
|
| 33 |
+
TRAINER_STATE_NAME,
|
| 34 |
+
PREFIX_CHECKPOINT_DIR,
|
| 35 |
+
logger,
|
| 36 |
+
speed_metrics,
|
| 37 |
+
deepspeed_init,
|
| 38 |
+
speed_metrics,
|
| 39 |
+
has_length,
|
| 40 |
+
EvalPrediction,
|
| 41 |
+
EvalLoopContainer,
|
| 42 |
+
PredictionOutput,
|
| 43 |
+
is_torch_xla_available,
|
| 44 |
+
denumpify_detensorize,
|
| 45 |
+
PredictionOutput,
|
| 46 |
+
EvalLoopOutput,
|
| 47 |
+
DistributedTensorGatherer,
|
| 48 |
+
SequentialDistributedSampler,
|
| 49 |
+
nested_concat,
|
| 50 |
+
)
|
| 51 |
+
from transformers.trainer_pt_utils import IterableDatasetShard
|
| 52 |
+
from transformers.trainer_callback import TrainerControl, TrainerState
|
| 53 |
+
|
| 54 |
+
from transformers.trainer_pt_utils import nested_detach, find_batch_size
|
| 55 |
+
from transformers.training_args import TrainingArguments
|
| 56 |
+
from trl import RewardTrainer
|
| 57 |
+
from hpsv3.utils.training_utils import get_peft_state_non_lora_maybe_zero_3
|
| 58 |
+
|
| 59 |
+
class Qwen2VLRewardModelBT(Qwen2VLForConditionalGeneration):
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
config,
|
| 63 |
+
output_dim=4,
|
| 64 |
+
reward_token="last",
|
| 65 |
+
special_token_ids=None,
|
| 66 |
+
rm_head_type="default",
|
| 67 |
+
rm_head_kwargs=None,
|
| 68 |
+
):
|
| 69 |
+
super().__init__(config)
|
| 70 |
+
# pdb.set_trace()
|
| 71 |
+
self.output_dim = output_dim
|
| 72 |
+
if rm_head_type == "default":
|
| 73 |
+
self.rm_head = nn.Linear(config.hidden_size, output_dim, bias=False)
|
| 74 |
+
elif rm_head_type == "ranknet":
|
| 75 |
+
if rm_head_kwargs is not None:
|
| 76 |
+
for layer in range(rm_head_kwargs.get("num_layers", 3)):
|
| 77 |
+
if layer == 0:
|
| 78 |
+
self.rm_head = nn.Sequential(
|
| 79 |
+
nn.Linear(config.hidden_size, rm_head_kwargs["hidden_size"]),
|
| 80 |
+
nn.ReLU(),
|
| 81 |
+
nn.Dropout(rm_head_kwargs.get("dropout", 0.1)),
|
| 82 |
+
)
|
| 83 |
+
elif layer < rm_head_kwargs.get("num_layers", 3) - 1:
|
| 84 |
+
self.rm_head.add_module(
|
| 85 |
+
f"layer_{layer}",
|
| 86 |
+
nn.Sequential(
|
| 87 |
+
nn.Linear(rm_head_kwargs["hidden_size"], rm_head_kwargs["hidden_size"]),
|
| 88 |
+
nn.ReLU(),
|
| 89 |
+
nn.Dropout(rm_head_kwargs.get("dropout", 0.1)),
|
| 90 |
+
),
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
self.rm_head.add_module(
|
| 94 |
+
f"output_layer",
|
| 95 |
+
nn.Linear(rm_head_kwargs["hidden_size"], output_dim, bias=rm_head_kwargs.get("bias", False)),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
else:
|
| 99 |
+
self.rm_head = nn.Sequential(
|
| 100 |
+
nn.Linear(config.hidden_size, 1024),
|
| 101 |
+
nn.ReLU(),
|
| 102 |
+
nn.Dropout(0.05),
|
| 103 |
+
nn.Linear(1024, 16),
|
| 104 |
+
nn.ReLU(),
|
| 105 |
+
nn.Linear(16, output_dim),
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
self.rm_head.to(torch.float32)
|
| 109 |
+
self.reward_token = reward_token
|
| 110 |
+
|
| 111 |
+
self.special_token_ids = special_token_ids
|
| 112 |
+
if self.special_token_ids is not None:
|
| 113 |
+
self.reward_token = "special"
|
| 114 |
+
|
| 115 |
+
def forward(
|
| 116 |
+
self,
|
| 117 |
+
input_ids: torch.LongTensor = None,
|
| 118 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 119 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 120 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 121 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 122 |
+
labels: Optional[torch.LongTensor] = None,
|
| 123 |
+
use_cache: Optional[bool] = None,
|
| 124 |
+
output_attentions: Optional[bool] = None,
|
| 125 |
+
output_hidden_states: Optional[bool] = None,
|
| 126 |
+
return_dict: Optional[bool] = None,
|
| 127 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 128 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 129 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 130 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 131 |
+
rope_deltas: Optional[torch.LongTensor] = None,
|
| 132 |
+
):
|
| 133 |
+
## modified from the origin class Qwen2VLForConditionalGeneration
|
| 134 |
+
output_attentions = (
|
| 135 |
+
output_attentions
|
| 136 |
+
if output_attentions is not None
|
| 137 |
+
else self.config.output_attentions
|
| 138 |
+
)
|
| 139 |
+
output_hidden_states = (
|
| 140 |
+
output_hidden_states
|
| 141 |
+
if output_hidden_states is not None
|
| 142 |
+
else self.config.output_hidden_states
|
| 143 |
+
)
|
| 144 |
+
return_dict = (
|
| 145 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 146 |
+
)
|
| 147 |
+
# pdb.set_trace()
|
| 148 |
+
if inputs_embeds is None:
|
| 149 |
+
inputs_embeds = self.model.embed_tokens(input_ids)
|
| 150 |
+
if pixel_values is not None:
|
| 151 |
+
pixel_values = pixel_values.type(self.visual.get_dtype())
|
| 152 |
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
| 153 |
+
image_mask = (
|
| 154 |
+
(input_ids == self.config.image_token_id)
|
| 155 |
+
.unsqueeze(-1)
|
| 156 |
+
.expand_as(inputs_embeds)
|
| 157 |
+
)
|
| 158 |
+
image_embeds = image_embeds.to(
|
| 159 |
+
inputs_embeds.device, inputs_embeds.dtype
|
| 160 |
+
)
|
| 161 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 162 |
+
|
| 163 |
+
if pixel_values_videos is not None:
|
| 164 |
+
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
| 165 |
+
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
| 166 |
+
video_mask = (
|
| 167 |
+
(input_ids == self.config.video_token_id)
|
| 168 |
+
.unsqueeze(-1)
|
| 169 |
+
.expand_as(inputs_embeds)
|
| 170 |
+
)
|
| 171 |
+
video_embeds = video_embeds.to(
|
| 172 |
+
inputs_embeds.device, inputs_embeds.dtype
|
| 173 |
+
)
|
| 174 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 175 |
+
|
| 176 |
+
if attention_mask is not None:
|
| 177 |
+
attention_mask = attention_mask.to(inputs_embeds.device)
|
| 178 |
+
|
| 179 |
+
outputs = self.model(
|
| 180 |
+
input_ids=None,
|
| 181 |
+
position_ids=position_ids,
|
| 182 |
+
attention_mask=attention_mask,
|
| 183 |
+
past_key_values=past_key_values,
|
| 184 |
+
inputs_embeds=inputs_embeds,
|
| 185 |
+
use_cache=use_cache,
|
| 186 |
+
output_attentions=output_attentions,
|
| 187 |
+
output_hidden_states=output_hidden_states,
|
| 188 |
+
return_dict=return_dict,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
hidden_states = outputs[0] # [B, L, D]
|
| 192 |
+
with torch.autocast(device_type='cuda', dtype=torch.float32):
|
| 193 |
+
logits = self.rm_head(hidden_states) # [B, L, N]
|
| 194 |
+
|
| 195 |
+
if input_ids is not None:
|
| 196 |
+
batch_size = input_ids.shape[0]
|
| 197 |
+
else:
|
| 198 |
+
batch_size = inputs_embeds.shape[0]
|
| 199 |
+
|
| 200 |
+
## get sequence length
|
| 201 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
| 202 |
+
raise ValueError(
|
| 203 |
+
"Cannot handle batch sizes > 1 if no padding token is defined."
|
| 204 |
+
)
|
| 205 |
+
if self.config.pad_token_id is None:
|
| 206 |
+
sequence_lengths = -1
|
| 207 |
+
else:
|
| 208 |
+
if input_ids is not None:
|
| 209 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
| 210 |
+
sequence_lengths = (
|
| 211 |
+
torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
| 212 |
+
)
|
| 213 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
| 214 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
| 215 |
+
else:
|
| 216 |
+
sequence_lengths = -1
|
| 217 |
+
|
| 218 |
+
## get the last token's logits
|
| 219 |
+
if self.reward_token == "last":
|
| 220 |
+
pooled_logits = logits[
|
| 221 |
+
torch.arange(batch_size, device=logits.device), sequence_lengths
|
| 222 |
+
]
|
| 223 |
+
elif self.reward_token == "mean":
|
| 224 |
+
## get the mean of all valid tokens' logits
|
| 225 |
+
valid_lengths = torch.clamp(sequence_lengths, min=0, max=logits.size(1) - 1)
|
| 226 |
+
pooled_logits = torch.stack(
|
| 227 |
+
[logits[i, : valid_lengths[i]].mean(dim=0) for i in range(batch_size)]
|
| 228 |
+
)
|
| 229 |
+
elif self.reward_token == "special":
|
| 230 |
+
# special_token_ids = self.tokenizer.convert_tokens_to_ids(self.special_tokens)
|
| 231 |
+
# create a mask for special tokens
|
| 232 |
+
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
| 233 |
+
for special_token_id in self.special_token_ids:
|
| 234 |
+
special_token_mask = special_token_mask | (
|
| 235 |
+
input_ids == special_token_id
|
| 236 |
+
)
|
| 237 |
+
pooled_logits = logits[special_token_mask, ...]
|
| 238 |
+
pooled_logits = pooled_logits.view(
|
| 239 |
+
batch_size, 1, -1
|
| 240 |
+
) # [B, 3, N] assert 3 attributes
|
| 241 |
+
pooled_logits = pooled_logits.view(batch_size, -1)
|
| 242 |
+
|
| 243 |
+
# pdb.set_trace()
|
| 244 |
+
else:
|
| 245 |
+
raise ValueError("Invalid reward_token")
|
| 246 |
+
|
| 247 |
+
return {"logits": pooled_logits}
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _convert_A_B_to_chosen_rejected(
|
| 251 |
+
rewards_A,
|
| 252 |
+
rewards_B,
|
| 253 |
+
tied_threshold=None,
|
| 254 |
+
choice_dist=None,
|
| 255 |
+
):
|
| 256 |
+
"""
|
| 257 |
+
Inputs:
|
| 258 |
+
rewards_A: [B, 1]
|
| 259 |
+
rewards_B: [B, 1]
|
| 260 |
+
Outputs:
|
| 261 |
+
rewards_chosen: [B, 1]
|
| 262 |
+
rewards_rejected: [B, 1]
|
| 263 |
+
nontied_mask: [B, 1] (preference labels that is not tied)
|
| 264 |
+
"""
|
| 265 |
+
chosen_label = torch.ones_like(rewards_A, dtype=torch.int64).to(
|
| 266 |
+
rewards_A.device
|
| 267 |
+
) # [B, 1]
|
| 268 |
+
chosen_mask = chosen_label == 1
|
| 269 |
+
rejected_mask = chosen_label != 1
|
| 270 |
+
|
| 271 |
+
rewards_chosen = rewards_A
|
| 272 |
+
rewards_rejected = rewards_B
|
| 273 |
+
|
| 274 |
+
if tied_threshold is None:
|
| 275 |
+
nontied_mask = torch.ones_like(chosen_label, dtype=torch.float32).to(
|
| 276 |
+
rewards_A.device
|
| 277 |
+
)
|
| 278 |
+
else:
|
| 279 |
+
nontied_mask = (
|
| 280 |
+
torch.abs(
|
| 281 |
+
(choice_dist[:, 0] - choice_dist[:, 1]) / torch.sum(choice_dist, dim=-1)
|
| 282 |
+
)
|
| 283 |
+
> tied_threshold
|
| 284 |
+
)
|
| 285 |
+
print(nontied_mask)
|
| 286 |
+
return (
|
| 287 |
+
rewards_chosen,
|
| 288 |
+
rewards_rejected,
|
| 289 |
+
nontied_mask,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class PartialEmbeddingUpdateCallback(TrainerCallback):
|
| 294 |
+
"""
|
| 295 |
+
Callback to update the embedding of special tokens
|
| 296 |
+
Only the special tokens are updated, the rest of the embeddings are kept fixed
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
def __init__(self, special_token_ids):
|
| 300 |
+
super().__init__()
|
| 301 |
+
self.special_token_ids = special_token_ids
|
| 302 |
+
self.orig_embeds_params = None
|
| 303 |
+
|
| 304 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 305 |
+
model = kwargs.get("model")
|
| 306 |
+
self.orig_embeds_params = model.get_input_embeddings().weight.clone().detach()
|
| 307 |
+
|
| 308 |
+
def on_step_end(self, args, state, control, **kwargs):
|
| 309 |
+
# pdb.set_trace()
|
| 310 |
+
model = kwargs.get("model")
|
| 311 |
+
tokenizer = kwargs.get("tokenizer")
|
| 312 |
+
|
| 313 |
+
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
|
| 314 |
+
index_no_updates[self.special_token_ids] = False
|
| 315 |
+
with torch.no_grad():
|
| 316 |
+
model.get_input_embeddings().weight[index_no_updates] = (
|
| 317 |
+
self.orig_embeds_params[index_no_updates]
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class VLMRewardTrainer(RewardTrainer):
|
| 322 |
+
def __init__(self, loss_type="regular", loss_hyperparameters={}, tied_threshold=None,
|
| 323 |
+
visualization_steps=500, max_viz_samples=4, *args, **kwargs):
|
| 324 |
+
super(VLMRewardTrainer, self).__init__(*args, **kwargs)
|
| 325 |
+
self.loss_type = loss_type
|
| 326 |
+
self.tied_threshold = tied_threshold
|
| 327 |
+
self.rewards_chosen_accumulated = []
|
| 328 |
+
self.rewards_rejected_accumulated = []
|
| 329 |
+
self.loss_hyperparameters = loss_hyperparameters
|
| 330 |
+
self.visualization_steps = visualization_steps
|
| 331 |
+
self.max_viz_samples = max_viz_samples
|
| 332 |
+
|
| 333 |
+
def get_eval_dataloader(
|
| 334 |
+
self, eval_dataset: Optional[Union[str, Dataset]] = None
|
| 335 |
+
) -> DataLoader:
|
| 336 |
+
"""
|
| 337 |
+
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
| 338 |
+
|
| 339 |
+
Subclass and override this method if you want to inject some custom behavior.
|
| 340 |
+
|
| 341 |
+
Args:
|
| 342 |
+
eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
|
| 343 |
+
If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
|
| 344 |
+
"""
|
| 345 |
+
if eval_dataset is None and self.eval_dataset is None:
|
| 346 |
+
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
| 347 |
+
|
| 348 |
+
# If we have persistent workers, don't do a fork bomb especially as eval datasets
|
| 349 |
+
# don't change during training
|
| 350 |
+
dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
|
| 351 |
+
if (
|
| 352 |
+
hasattr(self, "_eval_dataloaders")
|
| 353 |
+
and dataloader_key in self._eval_dataloaders
|
| 354 |
+
and self.args.dataloader_persistent_workers
|
| 355 |
+
):
|
| 356 |
+
return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
|
| 357 |
+
|
| 358 |
+
eval_dataset = (
|
| 359 |
+
self.eval_dataset[eval_dataset]
|
| 360 |
+
if isinstance(eval_dataset, str)
|
| 361 |
+
else eval_dataset if eval_dataset is not None else self.eval_dataset
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
data_collator = self.data_collator
|
| 365 |
+
|
| 366 |
+
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
| 367 |
+
eval_dataset = self._remove_unused_columns(
|
| 368 |
+
eval_dataset, description="evaluation"
|
| 369 |
+
)
|
| 370 |
+
else:
|
| 371 |
+
data_collator = self._get_collator_with_removed_columns(
|
| 372 |
+
data_collator, description="evaluation"
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
dataloader_params = {
|
| 376 |
+
"batch_size": self.args.eval_batch_size,
|
| 377 |
+
"collate_fn": data_collator,
|
| 378 |
+
"num_workers": self.args.dataloader_num_workers,
|
| 379 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
| 380 |
+
"persistent_workers": self.args.dataloader_persistent_workers,
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
| 384 |
+
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
|
| 385 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
| 386 |
+
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
| 387 |
+
|
| 388 |
+
# accelerator.free_memory() will destroy the references, so
|
| 389 |
+
# we need to store the non-prepared version
|
| 390 |
+
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
|
| 391 |
+
if self.args.dataloader_persistent_workers:
|
| 392 |
+
if hasattr(self, "_eval_dataloaders"):
|
| 393 |
+
self._eval_dataloaders[dataloader_key] = eval_dataloader
|
| 394 |
+
else:
|
| 395 |
+
self._eval_dataloaders = {dataloader_key: eval_dataloader}
|
| 396 |
+
|
| 397 |
+
return self.accelerator.prepare(eval_dataloader)
|
| 398 |
+
|
| 399 |
+
def create_optimizer(self):
|
| 400 |
+
"""
|
| 401 |
+
Setup the optimizer.
|
| 402 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
| 403 |
+
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
| 404 |
+
"""
|
| 405 |
+
if is_sagemaker_mp_enabled():
|
| 406 |
+
return super().create_optimizer()
|
| 407 |
+
|
| 408 |
+
opt_model = self.model
|
| 409 |
+
|
| 410 |
+
if self.optimizer is None:
|
| 411 |
+
decay_parameters = self.get_decay_parameter_names(opt_model)
|
| 412 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
| 413 |
+
lr_mapper = {}
|
| 414 |
+
visual_parameters = []
|
| 415 |
+
merger_parameters = []
|
| 416 |
+
rm_head_parameters = []
|
| 417 |
+
|
| 418 |
+
if self.args.vision_lr is not None:
|
| 419 |
+
lr_mapper["visual"] = self.args.vision_lr
|
| 420 |
+
visual_parameters = [
|
| 421 |
+
name
|
| 422 |
+
for name, _ in opt_model.named_parameters()
|
| 423 |
+
if "visual" in name and "merger" not in name
|
| 424 |
+
]
|
| 425 |
+
if self.args.merger_lr is not None:
|
| 426 |
+
lr_mapper["merger"] = self.args.merger_lr
|
| 427 |
+
merger_parameters = [
|
| 428 |
+
name for name, _ in opt_model.named_parameters() if "merger" in name
|
| 429 |
+
]
|
| 430 |
+
|
| 431 |
+
if self.args.rm_head_lr is not None:
|
| 432 |
+
lr_mapper["rm_head"] = self.args.rm_head_lr
|
| 433 |
+
rm_head_parameters = [
|
| 434 |
+
name for name, _ in opt_model.named_parameters() if "rm_head" in name
|
| 435 |
+
]
|
| 436 |
+
|
| 437 |
+
if len(lr_mapper) > 0:
|
| 438 |
+
special_lr_parameters = merger_parameters + visual_parameters + rm_head_parameters
|
| 439 |
+
|
| 440 |
+
optimizer_grouped_parameters = [
|
| 441 |
+
{
|
| 442 |
+
"params": [
|
| 443 |
+
p
|
| 444 |
+
for n, p in opt_model.named_parameters()
|
| 445 |
+
if (
|
| 446 |
+
n in decay_parameters
|
| 447 |
+
and n not in special_lr_parameters
|
| 448 |
+
and p.requires_grad
|
| 449 |
+
)
|
| 450 |
+
],
|
| 451 |
+
"weight_decay": self.args.weight_decay,
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"params": [
|
| 455 |
+
p
|
| 456 |
+
for n, p in opt_model.named_parameters()
|
| 457 |
+
if (
|
| 458 |
+
n not in decay_parameters
|
| 459 |
+
and n not in special_lr_parameters
|
| 460 |
+
and p.requires_grad
|
| 461 |
+
)
|
| 462 |
+
],
|
| 463 |
+
"weight_decay": 0.0,
|
| 464 |
+
},
|
| 465 |
+
]
|
| 466 |
+
|
| 467 |
+
if visual_parameters:
|
| 468 |
+
optimizer_grouped_parameters.extend(
|
| 469 |
+
[
|
| 470 |
+
{
|
| 471 |
+
"params": [
|
| 472 |
+
p
|
| 473 |
+
for n, p in opt_model.named_parameters()
|
| 474 |
+
if (
|
| 475 |
+
n in decay_parameters
|
| 476 |
+
and n in visual_parameters
|
| 477 |
+
and p.requires_grad
|
| 478 |
+
)
|
| 479 |
+
],
|
| 480 |
+
"weight_decay": self.args.weight_decay,
|
| 481 |
+
"lr": self.args.vision_lr,
|
| 482 |
+
},
|
| 483 |
+
{
|
| 484 |
+
"params": [
|
| 485 |
+
p
|
| 486 |
+
for n, p in opt_model.named_parameters()
|
| 487 |
+
if (
|
| 488 |
+
n not in decay_parameters
|
| 489 |
+
and n in visual_parameters
|
| 490 |
+
and p.requires_grad
|
| 491 |
+
)
|
| 492 |
+
],
|
| 493 |
+
"weight_decay": 0.0,
|
| 494 |
+
"lr": self.args.vision_lr,
|
| 495 |
+
},
|
| 496 |
+
]
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
if merger_parameters:
|
| 500 |
+
optimizer_grouped_parameters.extend(
|
| 501 |
+
[
|
| 502 |
+
{
|
| 503 |
+
"params": [
|
| 504 |
+
p
|
| 505 |
+
for n, p in opt_model.named_parameters()
|
| 506 |
+
if (
|
| 507 |
+
n in decay_parameters
|
| 508 |
+
and n in merger_parameters
|
| 509 |
+
and p.requires_grad
|
| 510 |
+
)
|
| 511 |
+
],
|
| 512 |
+
"weight_decay": self.args.weight_decay,
|
| 513 |
+
"lr": self.args.merger_lr,
|
| 514 |
+
},
|
| 515 |
+
{
|
| 516 |
+
"params": [
|
| 517 |
+
p
|
| 518 |
+
for n, p in opt_model.named_parameters()
|
| 519 |
+
if (
|
| 520 |
+
n not in decay_parameters
|
| 521 |
+
and n in merger_parameters
|
| 522 |
+
and p.requires_grad
|
| 523 |
+
)
|
| 524 |
+
],
|
| 525 |
+
"weight_decay": 0.0,
|
| 526 |
+
"lr": self.args.merger_lr,
|
| 527 |
+
},
|
| 528 |
+
]
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
if rm_head_parameters:
|
| 532 |
+
optimizer_grouped_parameters.extend(
|
| 533 |
+
[
|
| 534 |
+
{
|
| 535 |
+
"params": [
|
| 536 |
+
p
|
| 537 |
+
for n, p in opt_model.named_parameters()
|
| 538 |
+
if (
|
| 539 |
+
n in decay_parameters
|
| 540 |
+
and n in rm_head_parameters
|
| 541 |
+
and p.requires_grad
|
| 542 |
+
)
|
| 543 |
+
],
|
| 544 |
+
"weight_decay": self.args.weight_decay,
|
| 545 |
+
"lr": self.args.rm_head_lr,
|
| 546 |
+
},
|
| 547 |
+
{
|
| 548 |
+
"params": [
|
| 549 |
+
p
|
| 550 |
+
for n, p in opt_model.named_parameters()
|
| 551 |
+
if (
|
| 552 |
+
n not in decay_parameters
|
| 553 |
+
and n in rm_head_parameters
|
| 554 |
+
and p.requires_grad
|
| 555 |
+
)
|
| 556 |
+
],
|
| 557 |
+
"weight_decay": 0.0,
|
| 558 |
+
"lr": self.args.rm_head_lr,
|
| 559 |
+
},
|
| 560 |
+
]
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
else:
|
| 564 |
+
optimizer_grouped_parameters = [
|
| 565 |
+
{
|
| 566 |
+
"params": [
|
| 567 |
+
p
|
| 568 |
+
for n, p in opt_model.named_parameters()
|
| 569 |
+
if (n in decay_parameters and p.requires_grad)
|
| 570 |
+
],
|
| 571 |
+
"weight_decay": self.args.weight_decay,
|
| 572 |
+
},
|
| 573 |
+
{
|
| 574 |
+
"params": [
|
| 575 |
+
p
|
| 576 |
+
for n, p in opt_model.named_parameters()
|
| 577 |
+
if (n not in decay_parameters and p.requires_grad)
|
| 578 |
+
],
|
| 579 |
+
"weight_decay": 0.0,
|
| 580 |
+
},
|
| 581 |
+
]
|
| 582 |
+
|
| 583 |
+
if self.model.special_token_ids:
|
| 584 |
+
special_token_embeddings = opt_model.get_input_embeddings().weight
|
| 585 |
+
|
| 586 |
+
special_token_embeddings.requires_grad = True
|
| 587 |
+
|
| 588 |
+
optimizer_grouped_parameters.extend(
|
| 589 |
+
[
|
| 590 |
+
{
|
| 591 |
+
# "params": [p for n, p in opt_model.get_input_embeddings().named_parameters() if (p.requires_grad)],
|
| 592 |
+
"params": [special_token_embeddings],
|
| 593 |
+
"lr": self.args.special_token_lr,
|
| 594 |
+
"weight_decay": 0.0,
|
| 595 |
+
},
|
| 596 |
+
]
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
| 600 |
+
self.args, opt_model
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
self.optimizer = optimizer_cls(
|
| 604 |
+
optimizer_grouped_parameters, **optimizer_kwargs
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
return self.optimizer
|
| 608 |
+
|
| 609 |
+
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
| 610 |
+
rewards_A = model(return_dict=True, **inputs["batch_1"])["logits"]
|
| 611 |
+
rewards_B = model(return_dict=True, **inputs["batch_2"])["logits"]
|
| 612 |
+
|
| 613 |
+
# Log to TensorBoard for visualization
|
| 614 |
+
if (hasattr(self.state, 'global_step') and
|
| 615 |
+
self.state.global_step % self.visualization_steps == 0 and
|
| 616 |
+
self.state.global_step > 0):
|
| 617 |
+
# Pass the original inputs which should contain the text prompts
|
| 618 |
+
self._log_training_visualization(inputs, rewards_A, rewards_B)
|
| 619 |
+
|
| 620 |
+
# calculate loss, optionally modulate with margin
|
| 621 |
+
# get chosen and rejected rewards from the chosen label
|
| 622 |
+
(
|
| 623 |
+
rewards_chosen,
|
| 624 |
+
rewards_rejected,
|
| 625 |
+
nontied_mask,
|
| 626 |
+
) = _convert_A_B_to_chosen_rejected(
|
| 627 |
+
rewards_A,
|
| 628 |
+
rewards_B,
|
| 629 |
+
tied_threshold=self.tied_threshold,
|
| 630 |
+
choice_dist=inputs["choice_dist"],
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
loss_dict = {}
|
| 634 |
+
|
| 635 |
+
if self.loss_type == "bt":
|
| 636 |
+
# Bradley-Terry model
|
| 637 |
+
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected)
|
| 638 |
+
out_mask = nontied_mask
|
| 639 |
+
loss = loss * out_mask
|
| 640 |
+
loss = loss.mean()
|
| 641 |
+
elif self.loss_type == "likelihood_displacement":
|
| 642 |
+
# Bradley-Terry model
|
| 643 |
+
loss = -nn.functional.logsigmoid(rewards_chosen - self.loss_hyperparameters['tau'] * rewards_rejected)
|
| 644 |
+
out_mask = nontied_mask
|
| 645 |
+
loss = loss * out_mask
|
| 646 |
+
loss = loss.mean()
|
| 647 |
+
|
| 648 |
+
elif self.loss_type == "constant_margin":
|
| 649 |
+
# Bradley-Terry model with constant margin
|
| 650 |
+
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - 0.57)
|
| 651 |
+
out_mask = nontied_mask
|
| 652 |
+
loss = loss * out_mask
|
| 653 |
+
loss = loss.mean()
|
| 654 |
+
elif self.loss_type == "btt":
|
| 655 |
+
# Bradley-Terry-With-Ties model
|
| 656 |
+
k = 5.0
|
| 657 |
+
log_k = math.log(k)
|
| 658 |
+
log_k2_sub_1 = math.log(k**2 - 1)
|
| 659 |
+
bt_loss = -nn.functional.logsigmoid(
|
| 660 |
+
rewards_chosen - rewards_rejected - log_k
|
| 661 |
+
)
|
| 662 |
+
same_loss = (
|
| 663 |
+
-nn.functional.logsigmoid(rewards_chosen - rewards_rejected - log_k)
|
| 664 |
+
- nn.functional.logsigmoid(rewards_rejected - rewards_chosen - log_k)
|
| 665 |
+
- log_k2_sub_1
|
| 666 |
+
)
|
| 667 |
+
loss = bt_loss * nontied_mask.float() + same_loss * (
|
| 668 |
+
1 - nontied_mask.float()
|
| 669 |
+
)
|
| 670 |
+
out_mask = torch.ones_like(nontied_mask, dtype=torch.float32).to(
|
| 671 |
+
rewards_A.device
|
| 672 |
+
) # [B, 1]
|
| 673 |
+
loss = loss * out_mask
|
| 674 |
+
|
| 675 |
+
loss = loss.mean()
|
| 676 |
+
elif self.loss_type == "hpsv2":
|
| 677 |
+
device = rewards_A.device
|
| 678 |
+
rewards = torch.nn.functional.softmax(
|
| 679 |
+
torch.cat([rewards_A, rewards_B], dim=-1), dim=-1
|
| 680 |
+
)
|
| 681 |
+
text_0_logits, text_1_logits = rewards[:, 0], rewards[:, 1]
|
| 682 |
+
label_0, label_1 = torch.ones_like(text_0_logits), torch.zeros_like(
|
| 683 |
+
text_0_logits
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
|
| 687 |
+
text_0_labels = torch.zeros(
|
| 688 |
+
text_logits.shape[0], device=device, dtype=torch.long
|
| 689 |
+
)
|
| 690 |
+
text_1_labels = text_0_labels + 1
|
| 691 |
+
|
| 692 |
+
text_0_loss = torch.nn.functional.cross_entropy(
|
| 693 |
+
text_logits, text_0_labels, reduction="none"
|
| 694 |
+
)
|
| 695 |
+
text_1_loss = torch.nn.functional.cross_entropy(
|
| 696 |
+
text_logits, text_1_labels, reduction="none"
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
loss = label_0 * text_0_loss + label_1 * text_1_loss
|
| 700 |
+
|
| 701 |
+
# absolute_example_weight = 1 / num_per_prompt
|
| 702 |
+
# denominator = absolute_example_weight.sum()
|
| 703 |
+
# weight_per_example = absolute_example_weight / denominator
|
| 704 |
+
# text_loss *= weight_per_example
|
| 705 |
+
loss = loss.sum()
|
| 706 |
+
elif self.loss_type == "uncertainty":
|
| 707 |
+
batch_size = rewards_A.shape[0]
|
| 708 |
+
mean_chosen = rewards_A[:, 0]
|
| 709 |
+
mean_rejected = rewards_B[:, 0]
|
| 710 |
+
sigma_chosen = torch.exp(rewards_A[:, 1])
|
| 711 |
+
sigma_rejected = torch.exp(rewards_B[:, 1])
|
| 712 |
+
|
| 713 |
+
mean_z = mean_chosen - mean_rejected
|
| 714 |
+
sigma_z = torch.sqrt(sigma_chosen**2 + sigma_rejected**2)
|
| 715 |
+
|
| 716 |
+
z_samples = torch.randn(batch_size, 1000).to(sigma_z.device).to(
|
| 717 |
+
torch.float16
|
| 718 |
+
) * sigma_z.unsqueeze(1).repeat(1, 1000) + mean_z.unsqueeze(1).repeat(
|
| 719 |
+
1, 1000
|
| 720 |
+
)
|
| 721 |
+
loss = -torch.nn.functional.logsigmoid(z_samples).mean()
|
| 722 |
+
else:
|
| 723 |
+
raise NotImplementedError(f"Loss type {self.loss_type} not implemented.")
|
| 724 |
+
|
| 725 |
+
loss_dict.update({"loss": loss.item()})
|
| 726 |
+
|
| 727 |
+
if return_outputs:
|
| 728 |
+
## return rewards_A/B instead of chosen/rejected
|
| 729 |
+
## easier to calculate metrics for multi-attribute
|
| 730 |
+
return loss, {
|
| 731 |
+
"rewards_A": rewards_A,
|
| 732 |
+
"rewards_B": rewards_B,
|
| 733 |
+
}
|
| 734 |
+
return loss
|
| 735 |
+
|
| 736 |
+
def prediction_step(
|
| 737 |
+
self,
|
| 738 |
+
model,
|
| 739 |
+
inputs,
|
| 740 |
+
prediction_loss_only,
|
| 741 |
+
ignore_keys=None,
|
| 742 |
+
):
|
| 743 |
+
model.eval()
|
| 744 |
+
inputs = self._prepare_inputs(inputs)
|
| 745 |
+
if ignore_keys is None:
|
| 746 |
+
if hasattr(self.model, "config"):
|
| 747 |
+
ignore_keys = getattr(
|
| 748 |
+
self.model.config, "keys_to_ignore_at_inference", []
|
| 749 |
+
)
|
| 750 |
+
else:
|
| 751 |
+
ignore_keys = []
|
| 752 |
+
|
| 753 |
+
with torch.no_grad():
|
| 754 |
+
loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
|
| 755 |
+
|
| 756 |
+
if prediction_loss_only:
|
| 757 |
+
return (loss, None, None)
|
| 758 |
+
loss = loss.detach()
|
| 759 |
+
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
|
| 760 |
+
logits = nested_detach(logits)
|
| 761 |
+
if self.loss_type != "uncertainty":
|
| 762 |
+
logits = torch.cat(logits, dim=1) # [B, 2]
|
| 763 |
+
else:
|
| 764 |
+
logits = torch.cat([p[:, [0]] for p in logits], dim=1)
|
| 765 |
+
|
| 766 |
+
labels = torch.ones((logits.shape[0], 1)).to(logits.device)
|
| 767 |
+
|
| 768 |
+
return loss, logits, labels
|
| 769 |
+
|
| 770 |
+
def _log_training_visualization(self, inputs, rewards_A, rewards_B):
|
| 771 |
+
"""Log training samples and predictions to TensorBoard"""
|
| 772 |
+
try:
|
| 773 |
+
# Get tensorboard writer from trainer
|
| 774 |
+
writer = None
|
| 775 |
+
if hasattr(self, 'log_metrics'):
|
| 776 |
+
# Try to get the writer from the logger
|
| 777 |
+
if hasattr(self.args, 'report_to') and 'tensorboard' in self.args.report_to:
|
| 778 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 779 |
+
if not hasattr(self, '_tb_writer'):
|
| 780 |
+
self._tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
|
| 781 |
+
writer = self._tb_writer
|
| 782 |
+
|
| 783 |
+
if writer is None:
|
| 784 |
+
return
|
| 785 |
+
|
| 786 |
+
step = self.state.global_step
|
| 787 |
+
batch_size = min(len(rewards_A), self.max_viz_samples)
|
| 788 |
+
|
| 789 |
+
# Log scalar metrics
|
| 790 |
+
for i in range(batch_size):
|
| 791 |
+
score_A = rewards_A[i].float().detach().cpu().numpy()
|
| 792 |
+
score_B = rewards_B[i].float().detach().cpu().numpy()
|
| 793 |
+
|
| 794 |
+
# Convert to float for logging
|
| 795 |
+
score_A_val = float(score_A.mean()) if score_A.ndim > 0 else float(score_A)
|
| 796 |
+
score_B_val = float(score_B.mean()) if score_B.ndim > 0 else float(score_B)
|
| 797 |
+
score_diff = score_A_val - score_B_val
|
| 798 |
+
|
| 799 |
+
writer.add_scalar(f'train_viz/sample_{i}/score_A', score_A_val, step)
|
| 800 |
+
writer.add_scalar(f'train_viz/sample_{i}/score_B', score_B_val, step)
|
| 801 |
+
writer.add_scalar(f'train_viz/sample_{i}/score_diff', score_diff, step)
|
| 802 |
+
|
| 803 |
+
try:
|
| 804 |
+
# Get image data from inputs
|
| 805 |
+
image_A = inputs['image_1'][i] if 'image_1' in inputs else None
|
| 806 |
+
image_B = inputs['image_2'][i] if 'image_2' in inputs else None
|
| 807 |
+
|
| 808 |
+
# Get prompt text from the original batch (now properly stored)
|
| 809 |
+
prompt_A = inputs.get('text_1', ['Unknown prompt'])[i] if 'text_1' in inputs else 'Unknown prompt'
|
| 810 |
+
|
| 811 |
+
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
|
| 812 |
+
fig.text(0.05, 0.05, f'Prompt:\n{prompt_A[:200]}{"..." if len(prompt_A) > 200 else ""}',
|
| 813 |
+
ha='left', va='bottom', fontsize=8, wrap=True,
|
| 814 |
+
bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))
|
| 815 |
+
img_A_np = np.array(image_A)
|
| 816 |
+
if img_A_np.ndim == 3 and img_A_np.shape[0] == 3: # CHW format
|
| 817 |
+
img_A_np = np.transpose(img_A_np, (1, 2, 0))
|
| 818 |
+
img_A_np = np.clip(img_A_np, 0, 1) # Ensure values are in [0,1]
|
| 819 |
+
axes[0].imshow(img_A_np)
|
| 820 |
+
axes[0].set_title(f'Image A - Score: {score_A_val:.3f}')
|
| 821 |
+
axes[0].axis('off')
|
| 822 |
+
|
| 823 |
+
img_B_np = np.array(image_B)
|
| 824 |
+
if img_B_np.ndim == 3 and img_B_np.shape[0] == 3: # CHW format
|
| 825 |
+
img_B_np = np.transpose(img_B_np, (1, 2, 0))
|
| 826 |
+
img_B_np = np.clip(img_B_np, 0, 1) # Ensure values are in [0,1]
|
| 827 |
+
axes[1].imshow(img_B_np)
|
| 828 |
+
|
| 829 |
+
axes[1].set_title(f'Image B - Score: {score_B_val:.3f}')
|
| 830 |
+
axes[1].axis('off')
|
| 831 |
+
|
| 832 |
+
# Add prediction info
|
| 833 |
+
winner = "A" if score_diff > 0 else "B"
|
| 834 |
+
plt.suptitle(f'Step {step} - Sample {i} | Predicted Winner: Image {winner} | Diff: {score_diff:.3f}', fontsize=14)
|
| 835 |
+
plt.tight_layout()
|
| 836 |
+
|
| 837 |
+
# Log figure to tensorboard
|
| 838 |
+
writer.add_figure(f'train_viz/sample_{i}_comparison', fig, step)
|
| 839 |
+
plt.close(fig)
|
| 840 |
+
except Exception as viz_error:
|
| 841 |
+
print(f"Warning: Could not extract images for visualization: {viz_error}")
|
| 842 |
+
continue
|
| 843 |
+
|
| 844 |
+
# Log aggregate statistics
|
| 845 |
+
all_scores_A = rewards_A.float().detach().cpu().numpy()
|
| 846 |
+
all_scores_B = rewards_B.float().detach().cpu().numpy()
|
| 847 |
+
|
| 848 |
+
writer.add_histogram('train_viz/all_scores_A', all_scores_A, step)
|
| 849 |
+
writer.add_histogram('train_viz/all_scores_B', all_scores_B, step)
|
| 850 |
+
writer.add_scalar('train_viz/mean_score_A', float(all_scores_A.mean()), step)
|
| 851 |
+
writer.add_scalar('train_viz/mean_score_B', float(all_scores_B.mean()), step)
|
| 852 |
+
writer.add_scalar('train_viz/mean_score_diff', float((all_scores_A - all_scores_B).mean()), step)
|
| 853 |
+
|
| 854 |
+
except Exception as e:
|
| 855 |
+
print(f"Error in training visualization: {e}")
|
| 856 |
+
|
| 857 |
+
def _save_checkpoint(self, model, trial, metrics=None):
|
| 858 |
+
|
| 859 |
+
if isinstance(self.model, PeftModel):
|
| 860 |
+
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
| 861 |
+
|
| 862 |
+
if self.hp_search_backend is None and trial is None:
|
| 863 |
+
self.store_flos()
|
| 864 |
+
|
| 865 |
+
run_dir = self._get_output_dir(trial=trial)
|
| 866 |
+
output_dir = os.path.join(run_dir, checkpoint_folder)
|
| 867 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 868 |
+
|
| 869 |
+
# TODO: Just Temp
|
| 870 |
+
self.save_model(output_dir, _internal_call=True)
|
| 871 |
+
# pdb.set_trace()
|
| 872 |
+
|
| 873 |
+
if not self.args.save_full_model:
|
| 874 |
+
non_lora_weights = get_peft_state_non_lora_maybe_zero_3(
|
| 875 |
+
self.model.named_parameters(), require_grad_only=True
|
| 876 |
+
)
|
| 877 |
+
torch.save(
|
| 878 |
+
non_lora_weights,
|
| 879 |
+
os.path.join(output_dir, "non_lora_state_dict.pth"),
|
| 880 |
+
)
|
| 881 |
+
# safetensors.torch.save(non_lora_weights, os.path.join(output_dir, "non_lora_model.safetensors"))
|
| 882 |
+
|
| 883 |
+
if not self.args.save_only_model:
|
| 884 |
+
# Save optimizer and scheduler
|
| 885 |
+
self._save_optimizer_and_scheduler(output_dir)
|
| 886 |
+
# Save RNG state
|
| 887 |
+
self._save_rng_state(output_dir)
|
| 888 |
+
|
| 889 |
+
else:
|
| 890 |
+
super(RewardTrainer, self)._save_checkpoint(model, trial, metrics)
|
| 891 |
+
|
| 892 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
| 893 |
+
# If we are executing this function, we are the process zero, so we don't check for that.
|
| 894 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
| 895 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 896 |
+
logger.info(f"Saving model checkpoint to {output_dir}")
|
| 897 |
+
# pdb.set_trace()
|
| 898 |
+
|
| 899 |
+
supported_classes = (
|
| 900 |
+
(PreTrainedModel,)
|
| 901 |
+
if not is_peft_available()
|
| 902 |
+
else (PreTrainedModel, PeftModel)
|
| 903 |
+
)
|
| 904 |
+
# Save a trained model and configuration using `save_pretrained()`.
|
| 905 |
+
# They can then be reloaded using `from_pretrained()`
|
| 906 |
+
if not isinstance(self.model, supported_classes):
|
| 907 |
+
if state_dict is None:
|
| 908 |
+
state_dict = self.model.state_dict()
|
| 909 |
+
|
| 910 |
+
if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
|
| 911 |
+
self.accelerator.unwrap_model(self.model).save_pretrained(
|
| 912 |
+
output_dir,
|
| 913 |
+
state_dict=state_dict,
|
| 914 |
+
safe_serialization=self.args.save_safetensors,
|
| 915 |
+
)
|
| 916 |
+
else:
|
| 917 |
+
logger.info(
|
| 918 |
+
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
|
| 919 |
+
)
|
| 920 |
+
if self.args.save_safetensors:
|
| 921 |
+
safetensors.torch.save_file(
|
| 922 |
+
state_dict,
|
| 923 |
+
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
| 924 |
+
metadata={"format": "pt"},
|
| 925 |
+
)
|
| 926 |
+
else:
|
| 927 |
+
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
| 928 |
+
else:
|
| 929 |
+
if not self.args.save_full_model:
|
| 930 |
+
state_dict = {k: v for k, v in state_dict.items() if "wte" not in k}
|
| 931 |
+
self.model.save_pretrained(
|
| 932 |
+
output_dir,
|
| 933 |
+
state_dict=state_dict,
|
| 934 |
+
safe_serialization=self.args.save_safetensors,
|
| 935 |
+
)
|
| 936 |
+
else:
|
| 937 |
+
torch.save(state_dict, os.path.join(output_dir, "model.pth"))
|
| 938 |
+
|
| 939 |
+
if self.tokenizer is not None:
|
| 940 |
+
os.makedirs(os.path.join(output_dir, "tokenizer"), exist_ok=True)
|
| 941 |
+
self.tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
|
| 942 |
+
|
| 943 |
+
# Good practice: save your training arguments together with the trained model
|
| 944 |
+
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
| 945 |
+
# pdb.set_trace()
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
def compute_multi_attr_accuracy(eval_pred, metainfo_idxs=None) -> Dict[str, float]:
|
| 949 |
+
predictions, labels = eval_pred
|
| 950 |
+
metrics = {}
|
| 951 |
+
|
| 952 |
+
pred_curr = predictions
|
| 953 |
+
label_curr = labels.squeeze(1)
|
| 954 |
+
total_count = np.sum(label_curr != 0)
|
| 955 |
+
|
| 956 |
+
rewards_chosen = pred_curr[:, 0]
|
| 957 |
+
rewards_rejected = pred_curr[:, 1]
|
| 958 |
+
|
| 959 |
+
rewards_chosen_avg = np.sum(rewards_chosen) / total_count
|
| 960 |
+
rewards_rejected_avg = np.sum(rewards_rejected) / total_count
|
| 961 |
+
|
| 962 |
+
accuracy = np.sum(rewards_chosen > rewards_rejected) / total_count
|
| 963 |
+
|
| 964 |
+
metrics.update(
|
| 965 |
+
{
|
| 966 |
+
f"Acc": accuracy,
|
| 967 |
+
f"R_chosen_avg": rewards_chosen_avg,
|
| 968 |
+
f"R_rejected_avg": rewards_rejected_avg,
|
| 969 |
+
}
|
| 970 |
+
)
|
| 971 |
+
return metrics
|
hpsv3/model/test_differentiable.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
from collections.abc import Mapping
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import huggingface_hub
|
| 8 |
+
from hpsv3.dataset.utils import process_vision_info
|
| 9 |
+
from hpsv3.dataset.data_collator_qwen import prompt_with_special_token, prompt_without_special_token, INSTRUCTION
|
| 10 |
+
from hpsv3.utils.parser import ModelConfig, PEFTLoraConfig, TrainingConfig, DataConfig, parse_args_with_yaml
|
| 11 |
+
from hpsv3.train import create_model_and_processor
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
_MODEL_CONFIG_PATH = Path(__file__).parent / f"config/"
|
| 15 |
+
|
| 16 |
+
class HPSv3RewardInferencer():
|
| 17 |
+
def __init__(self, config_path=None, checkpoint_path=None, device='cuda', differentiable=False):
|
| 18 |
+
if config_path is None:
|
| 19 |
+
config_path = os.path.join(_MODEL_CONFIG_PATH, 'HPSv3_7B.yaml')
|
| 20 |
+
|
| 21 |
+
if checkpoint_path is None:
|
| 22 |
+
checkpoint_path = os.path.join(huggingface_hub.hf_hub_download("xilanhua12138/HPSv3"), 'model.pth')
|
| 23 |
+
|
| 24 |
+
(data_config, training_args, model_config, peft_lora_config), config_path = (
|
| 25 |
+
parse_args_with_yaml(
|
| 26 |
+
(DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig), config_path, is_train=False
|
| 27 |
+
)
|
| 28 |
+
)
|
| 29 |
+
training_args.output_dir = os.path.join(
|
| 30 |
+
training_args.output_dir, config_path.split("/")[-1].split(".")[0]
|
| 31 |
+
)
|
| 32 |
+
model, processor, peft_config = create_model_and_processor(
|
| 33 |
+
model_config=model_config,
|
| 34 |
+
peft_lora_config=peft_lora_config,
|
| 35 |
+
training_args=training_args,
|
| 36 |
+
differentiable=differentiable,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
self.device = device
|
| 40 |
+
self.use_special_tokens = model_config.use_special_tokens
|
| 41 |
+
|
| 42 |
+
state_dict = torch.load(checkpoint_path , map_location="cpu")
|
| 43 |
+
if "model" in state_dict:
|
| 44 |
+
state_dict = state_dict["model"]
|
| 45 |
+
model.load_state_dict(state_dict, strict=False)
|
| 46 |
+
model.eval()
|
| 47 |
+
|
| 48 |
+
self.model = model
|
| 49 |
+
self.processor = processor
|
| 50 |
+
|
| 51 |
+
self.model.to(self.device)
|
| 52 |
+
self.data_config = data_config
|
| 53 |
+
|
| 54 |
+
def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'):
|
| 55 |
+
"""
|
| 56 |
+
Pad the sequences to the maximum length.
|
| 57 |
+
"""
|
| 58 |
+
assert padding_side in ['right', 'left']
|
| 59 |
+
if sequences.shape[1] >= max_len:
|
| 60 |
+
return sequences, attention_mask
|
| 61 |
+
|
| 62 |
+
pad_len = max_len - sequences.shape[1]
|
| 63 |
+
padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0)
|
| 64 |
+
|
| 65 |
+
sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id)
|
| 66 |
+
attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0)
|
| 67 |
+
|
| 68 |
+
return sequences_padded, attention_mask_padded
|
| 69 |
+
|
| 70 |
+
def _prepare_input(self, data):
|
| 71 |
+
"""
|
| 72 |
+
Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
|
| 73 |
+
handling potential state.
|
| 74 |
+
"""
|
| 75 |
+
if isinstance(data, Mapping):
|
| 76 |
+
return type(data)({k: self._prepare_input(v) for k, v in data.items()})
|
| 77 |
+
elif isinstance(data, (tuple, list)):
|
| 78 |
+
return type(data)(self._prepare_input(v) for v in data)
|
| 79 |
+
elif isinstance(data, torch.Tensor):
|
| 80 |
+
kwargs = {"device": self.device}
|
| 81 |
+
## TODO: Maybe need to add dtype
|
| 82 |
+
# if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
|
| 83 |
+
# # NLP models inputs are int/uint and those get adjusted to the right dtype of the
|
| 84 |
+
# # embedding. Other models such as wav2vec2's inputs are already float and thus
|
| 85 |
+
# # may need special handling to match the dtypes of the model
|
| 86 |
+
# kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
|
| 87 |
+
return data.to(**kwargs)
|
| 88 |
+
return data
|
| 89 |
+
|
| 90 |
+
def _prepare_inputs(self, inputs):
|
| 91 |
+
"""
|
| 92 |
+
Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
|
| 93 |
+
handling potential state.
|
| 94 |
+
"""
|
| 95 |
+
inputs = self._prepare_input(inputs)
|
| 96 |
+
if len(inputs) == 0:
|
| 97 |
+
raise ValueError
|
| 98 |
+
return inputs
|
| 99 |
+
|
| 100 |
+
def prepare_batch(self, image_paths, prompts):
|
| 101 |
+
max_pixels = 256 * 28 * 28
|
| 102 |
+
min_pixels = 256 * 28 * 28
|
| 103 |
+
message_list = []
|
| 104 |
+
for text, image in zip(prompts, image_paths):
|
| 105 |
+
out_message = [
|
| 106 |
+
{
|
| 107 |
+
"role": "user",
|
| 108 |
+
"content": [
|
| 109 |
+
{
|
| 110 |
+
"type": "image",
|
| 111 |
+
"image": image,
|
| 112 |
+
"min_pixels": max_pixels,
|
| 113 |
+
"max_pixels": max_pixels,
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"type": "text",
|
| 117 |
+
"text": (
|
| 118 |
+
INSTRUCTION.format(text_prompt=text)
|
| 119 |
+
+ prompt_with_special_token
|
| 120 |
+
if self.use_special_tokens
|
| 121 |
+
else prompt_without_special_token
|
| 122 |
+
),
|
| 123 |
+
},
|
| 124 |
+
],
|
| 125 |
+
}
|
| 126 |
+
]
|
| 127 |
+
|
| 128 |
+
message_list.append(out_message)
|
| 129 |
+
|
| 130 |
+
image_inputs, _ = process_vision_info(message_list)
|
| 131 |
+
|
| 132 |
+
batch = self.processor(
|
| 133 |
+
text=self.processor.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True),
|
| 134 |
+
images=image_inputs,
|
| 135 |
+
padding=True,
|
| 136 |
+
return_tensors="pt",
|
| 137 |
+
videos_kwargs={"do_rescale": True},
|
| 138 |
+
)
|
| 139 |
+
batch = self._prepare_inputs(batch)
|
| 140 |
+
return batch
|
| 141 |
+
|
| 142 |
+
def reward(self, image_paths, prompts):
|
| 143 |
+
|
| 144 |
+
batch = self.prepare_batch(image_paths, prompts)
|
| 145 |
+
rewards = self.model(
|
| 146 |
+
return_dict=True,
|
| 147 |
+
**batch
|
| 148 |
+
)["logits"]
|
| 149 |
+
|
| 150 |
+
return rewards
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
if __name__ == "__main__":
|
| 154 |
+
config_path = '/preflab/shuiyunhao/tasks/HPSv3_official/hpsv3/config/HPSv3_7B.yaml'
|
| 155 |
+
checkpoint_path = '/preflab/shuiyunhao/tasks/HPSv3_official/checkpoints/HPSv3_7B/model.pth'
|
| 156 |
+
device = 'cuda'
|
| 157 |
+
dtype = torch.bfloat16
|
| 158 |
+
inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, differentiable=True, device=device)
|
| 159 |
+
|
| 160 |
+
images = [
|
| 161 |
+
torch.from_numpy(np.array(Image.open("assets/example1.png"))),
|
| 162 |
+
torch.from_numpy(np.array(Image.open("assets/example2.png")))
|
| 163 |
+
]
|
| 164 |
+
prompts = [
|
| 165 |
+
"cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker",
|
| 166 |
+
"cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker"
|
| 167 |
+
]
|
| 168 |
+
rewards = inferencer.reward(images, prompts)
|
| 169 |
+
print(rewards[0][0].item()) # miu and sigma. we select miu as the final output
|
| 170 |
+
print(rewards[1][0].item())
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
loss = rewards[0][0]
|
| 174 |
+
print(f"Loss value: {loss.item()}")
|
| 175 |
+
print(f"Loss requires_grad: {loss.requires_grad}")
|
| 176 |
+
|
| 177 |
+
if loss.requires_grad:
|
| 178 |
+
loss.backward(retain_graph=True)
|
| 179 |
+
|
| 180 |
+
has_grad = False
|
| 181 |
+
for name, param in inferencer.model.named_parameters():
|
| 182 |
+
if param.grad is not None:
|
| 183 |
+
has_grad = True
|
| 184 |
+
print(f"参数 {name} 有梯度: {param.grad.norm().item():.6f}")
|
| 185 |
+
break
|
| 186 |
+
|
| 187 |
+
if has_grad:
|
| 188 |
+
print("has grad")
|
| 189 |
+
else:
|
| 190 |
+
print("NO GRAD!!")
|
| 191 |
+
else:
|
| 192 |
+
print("Final loss does not require gradient computation.")
|
| 193 |
+
|
| 194 |
+
# Compare
|
| 195 |
+
img_pil = Image.open("assets/example1.png").convert('RGB')
|
| 196 |
+
non_diff_result = inferencer.processor.preprocess(
|
| 197 |
+
images=[img_pil],
|
| 198 |
+
return_tensors="pt"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
img_tensor = torch.from_numpy(np.array(img_pil)).float().permute(2, 0, 1) / 255.0
|
| 202 |
+
diff_result = inferencer.processor.preprocess_tensor(img_tensor)
|
| 203 |
+
|
| 204 |
+
if non_diff_result['pixel_values'].shape == diff_result['pixel_values'].shape:
|
| 205 |
+
diff = torch.abs(non_diff_result['pixel_values'] - diff_result['pixel_values']).mean()
|
| 206 |
+
if diff.item() < 0.01:
|
| 207 |
+
print("Right")
|
| 208 |
+
else:
|
| 209 |
+
print("Different outputs")
|
| 210 |
+
else:
|
| 211 |
+
print("Shape mismatch between non-differentiable and differentiable outputs.")
|
| 212 |
+
|
hpsv3/train.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import fire
|
| 4 |
+
from dataclasses import asdict
|
| 5 |
+
from functools import partial
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
from hpsv3.model.qwen2vl_trainer import (
|
| 9 |
+
Qwen2VLRewardModelBT,
|
| 10 |
+
VLMRewardTrainer,
|
| 11 |
+
compute_multi_attr_accuracy,
|
| 12 |
+
PartialEmbeddingUpdateCallback,
|
| 13 |
+
)
|
| 14 |
+
from hpsv3.dataset.pairwise_dataset import PairwiseOriginalDataset
|
| 15 |
+
from hpsv3.dataset.data_collator_qwen import QWen2VLDataCollator
|
| 16 |
+
from hpsv3.utils.parser import ModelConfig, PEFTLoraConfig, TrainingConfig, DataConfig
|
| 17 |
+
from hpsv3.utils.training_utils import load_model_from_checkpoint, find_target_linear_names
|
| 18 |
+
from hpsv3.utils.parser import parse_args_with_yaml
|
| 19 |
+
from transformers import AutoProcessor
|
| 20 |
+
from peft import LoraConfig, get_peft_model
|
| 21 |
+
from trl import get_kbit_device_map, get_quantization_config
|
| 22 |
+
from hpsv3.model.differentiable_image_processor import Qwen2VLImageProcessor
|
| 23 |
+
try:
|
| 24 |
+
import flash_attn
|
| 25 |
+
except ImportError:
|
| 26 |
+
flash_attn = None
|
| 27 |
+
print("Flash Attention is not installed. Falling to SDPA.")
|
| 28 |
+
|
| 29 |
+
def create_model_and_processor(
|
| 30 |
+
model_config,
|
| 31 |
+
peft_lora_config,
|
| 32 |
+
training_args,
|
| 33 |
+
cache_dir=None,
|
| 34 |
+
differentiable=False,
|
| 35 |
+
):
|
| 36 |
+
# create model
|
| 37 |
+
torch_dtype = (
|
| 38 |
+
model_config.torch_dtype
|
| 39 |
+
if model_config.torch_dtype in ["auto", None]
|
| 40 |
+
else getattr(torch, model_config.torch_dtype)
|
| 41 |
+
)
|
| 42 |
+
quantization_config = get_quantization_config(model_config)
|
| 43 |
+
model_kwargs = dict(
|
| 44 |
+
revision=model_config.model_revision,
|
| 45 |
+
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
| 46 |
+
quantization_config=quantization_config,
|
| 47 |
+
use_cache=False
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# create processor and set padding
|
| 51 |
+
|
| 52 |
+
processor = AutoProcessor.from_pretrained(
|
| 53 |
+
model_config.model_name_or_path, padding_side="right", cache_dir=cache_dir
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
if differentiable:
|
| 57 |
+
processor.image_processor = Qwen2VLImageProcessor()
|
| 58 |
+
|
| 59 |
+
special_token_ids = None
|
| 60 |
+
if model_config.use_special_tokens:
|
| 61 |
+
special_tokens = ["<|Reward|>"]
|
| 62 |
+
processor.tokenizer.add_special_tokens(
|
| 63 |
+
{"additional_special_tokens": special_tokens}
|
| 64 |
+
)
|
| 65 |
+
special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens)
|
| 66 |
+
|
| 67 |
+
model = Qwen2VLRewardModelBT.from_pretrained(
|
| 68 |
+
model_config.model_name_or_path,
|
| 69 |
+
output_dim=model_config.output_dim,
|
| 70 |
+
reward_token=model_config.reward_token,
|
| 71 |
+
special_token_ids=special_token_ids,
|
| 72 |
+
torch_dtype=torch_dtype,
|
| 73 |
+
attn_implementation=(
|
| 74 |
+
"flash_attention_2" if not training_args.disable_flash_attn2 and flash_attn is not None else "sdpa"
|
| 75 |
+
),
|
| 76 |
+
cache_dir=cache_dir,
|
| 77 |
+
rm_head_type=model_config.rm_head_type,
|
| 78 |
+
rm_head_kwargs=model_config.rm_head_kwargs,
|
| 79 |
+
**model_kwargs,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
if model_config.use_special_tokens:
|
| 83 |
+
model.resize_token_embeddings(len(processor.tokenizer))
|
| 84 |
+
|
| 85 |
+
if training_args.bf16:
|
| 86 |
+
model.to(torch.bfloat16)
|
| 87 |
+
if training_args.fp16:
|
| 88 |
+
model.to(torch.float16)
|
| 89 |
+
|
| 90 |
+
model.rm_head.to(torch.float32)
|
| 91 |
+
|
| 92 |
+
# create lora and peft model
|
| 93 |
+
if peft_lora_config.lora_enable:
|
| 94 |
+
target_modules = find_target_linear_names(
|
| 95 |
+
model,
|
| 96 |
+
num_lora_modules=peft_lora_config.num_lora_modules,
|
| 97 |
+
lora_namespan_exclude=peft_lora_config.lora_namespan_exclude,
|
| 98 |
+
)
|
| 99 |
+
peft_config = LoraConfig(
|
| 100 |
+
target_modules=target_modules,
|
| 101 |
+
r=peft_lora_config.lora_r,
|
| 102 |
+
lora_alpha=peft_lora_config.lora_alpha,
|
| 103 |
+
lora_dropout=peft_lora_config.lora_dropout,
|
| 104 |
+
task_type=peft_lora_config.lora_task_type,
|
| 105 |
+
use_rslora=peft_lora_config.use_rslora,
|
| 106 |
+
bias="none",
|
| 107 |
+
modules_to_save=peft_lora_config.lora_modules_to_save,
|
| 108 |
+
)
|
| 109 |
+
model = get_peft_model(model, peft_config)
|
| 110 |
+
else:
|
| 111 |
+
peft_config = None
|
| 112 |
+
|
| 113 |
+
model.config.tokenizer_padding_side = processor.tokenizer.padding_side
|
| 114 |
+
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
| 115 |
+
|
| 116 |
+
return model, processor, peft_config
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def save_configs_to_json(data_config, training_args, model_config, peft_lora_config):
|
| 120 |
+
"""
|
| 121 |
+
Save all configurations to a JSON file.
|
| 122 |
+
"""
|
| 123 |
+
config_dict = {
|
| 124 |
+
"data_config": asdict(data_config),
|
| 125 |
+
"training_args": asdict(training_args),
|
| 126 |
+
"model_config": asdict(model_config),
|
| 127 |
+
"peft_lora_config": asdict(peft_lora_config),
|
| 128 |
+
}
|
| 129 |
+
# del information about local device
|
| 130 |
+
del config_dict["training_args"]["local_rank"]
|
| 131 |
+
del config_dict["training_args"]["_n_gpu"]
|
| 132 |
+
|
| 133 |
+
save_path = os.path.join(training_args.output_dir, "model_config.json")
|
| 134 |
+
|
| 135 |
+
os.makedirs(training_args.output_dir, exist_ok=True)
|
| 136 |
+
print(training_args.output_dir)
|
| 137 |
+
|
| 138 |
+
with open(save_path, "w") as f:
|
| 139 |
+
json.dump(config_dict, f, indent=4)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def set_requires_grad(parameters, requires_grad):
|
| 143 |
+
for p in parameters:
|
| 144 |
+
p.requires_grad = requires_grad
|
| 145 |
+
|
| 146 |
+
def train(config, local_rank=0, debug=False):
|
| 147 |
+
|
| 148 |
+
## ===> Step 1: Parse arguments
|
| 149 |
+
(data_config, training_args, model_config, peft_lora_config), config_path = (
|
| 150 |
+
parse_args_with_yaml(
|
| 151 |
+
(DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig), config, is_train=True
|
| 152 |
+
)
|
| 153 |
+
)
|
| 154 |
+
training_args.output_dir = os.path.join(
|
| 155 |
+
training_args.output_dir, config.split("/")[-1].split(".")[0]
|
| 156 |
+
)
|
| 157 |
+
training_args.logging_dir = training_args.output_dir
|
| 158 |
+
# check valid (lora config)
|
| 159 |
+
assert not (
|
| 160 |
+
peft_lora_config.lora_enable and model_config.freeze_llm
|
| 161 |
+
), "When using LoRA, the LLM should not be frozen. If you want to freeze the LLM, please disable LoRA."
|
| 162 |
+
if not peft_lora_config.lora_enable:
|
| 163 |
+
assert (
|
| 164 |
+
not peft_lora_config.vision_lora
|
| 165 |
+
), "Error: model_config.lora_enable is not enabled, but model_config.vision_lora is enabled."
|
| 166 |
+
else:
|
| 167 |
+
if peft_lora_config.lora_namespan_exclude is None:
|
| 168 |
+
peft_lora_config.lora_namespan_exclude = []
|
| 169 |
+
if not peft_lora_config.vision_lora:
|
| 170 |
+
peft_lora_config.lora_namespan_exclude += ["visual"]
|
| 171 |
+
|
| 172 |
+
## ===> Step 2: Load model and configure
|
| 173 |
+
model, processor, peft_config = create_model_and_processor(
|
| 174 |
+
model_config=model_config,
|
| 175 |
+
peft_lora_config=peft_lora_config,
|
| 176 |
+
training_args=training_args,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
## load model
|
| 180 |
+
if training_args.load_from_pretrained is not None:
|
| 181 |
+
model, checkpoint_step = load_model_from_checkpoint(
|
| 182 |
+
model,
|
| 183 |
+
training_args.load_from_pretrained,
|
| 184 |
+
training_args.load_from_pretrained_step,
|
| 185 |
+
)
|
| 186 |
+
model.train()
|
| 187 |
+
|
| 188 |
+
if peft_lora_config.lora_enable:
|
| 189 |
+
model_to_configure = model.model
|
| 190 |
+
else:
|
| 191 |
+
model_to_configure = model
|
| 192 |
+
# set requires_grad for LLM
|
| 193 |
+
set_requires_grad(
|
| 194 |
+
model_to_configure.model.parameters(), not model_config.freeze_llm
|
| 195 |
+
)
|
| 196 |
+
set_requires_grad(model_to_configure.model.embed_tokens.parameters(), False)
|
| 197 |
+
if not peft_lora_config.vision_lora:
|
| 198 |
+
# set requires_grad for visual encoder and merger
|
| 199 |
+
set_requires_grad(
|
| 200 |
+
model_to_configure.visual.parameters(), not model_config.freeze_vision_tower
|
| 201 |
+
)
|
| 202 |
+
set_requires_grad(
|
| 203 |
+
model_to_configure.visual.merger.parameters(), model_config.tune_merger
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
if model_config.trainable_visual_layers: # This is inverse order to index of model.visual.blocks, set -1 to unfreeze all layers
|
| 207 |
+
assert model_config.trainable_visual_layers <= len(model_to_configure.visual.blocks), "trainable_visual_layers should be less than or equal to the number of visual blocks"
|
| 208 |
+
freeze_layer_num = len(model_to_configure.visual.blocks) - model_config.trainable_visual_layers if model_config.trainable_visual_layers > 0 else 0
|
| 209 |
+
for index, layer in enumerate(model_to_configure.visual.blocks):
|
| 210 |
+
if index < freeze_layer_num:
|
| 211 |
+
set_requires_grad(layer.parameters(), False)
|
| 212 |
+
else:
|
| 213 |
+
set_requires_grad(layer.parameters(), True)
|
| 214 |
+
|
| 215 |
+
# set requires_grad for regression head
|
| 216 |
+
set_requires_grad(model_to_configure.rm_head.parameters(), True)
|
| 217 |
+
|
| 218 |
+
## ===> Step 3: Load Dataset and configure
|
| 219 |
+
train_dataset = PairwiseOriginalDataset(
|
| 220 |
+
data_config.train_json_list,
|
| 221 |
+
data_config.soft_label,
|
| 222 |
+
data_config.confidence_threshold,
|
| 223 |
+
)
|
| 224 |
+
test_set_dict = {}
|
| 225 |
+
for item in data_config.test_json_list:
|
| 226 |
+
test_set_dict[item[0]] = PairwiseOriginalDataset(
|
| 227 |
+
item[1],
|
| 228 |
+
data_config.soft_label,
|
| 229 |
+
data_config.confidence_threshold,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
print(f"===> Selected {len(train_dataset)} samples for training.")
|
| 233 |
+
for key, value in test_set_dict.items():
|
| 234 |
+
print(f"===> Selected {len(value)} samples for {key} testing.")
|
| 235 |
+
|
| 236 |
+
num_gpu = int(os.environ.get("WORLD_SIZE", 1))
|
| 237 |
+
data_collator = QWen2VLDataCollator(
|
| 238 |
+
processor,
|
| 239 |
+
max_pixels=data_config.max_pixels,
|
| 240 |
+
min_pixels=data_config.min_pixels,
|
| 241 |
+
with_instruction=data_config.with_instruction,
|
| 242 |
+
use_special_tokens=model_config.use_special_tokens,
|
| 243 |
+
)
|
| 244 |
+
compute_metrics = partial(compute_multi_attr_accuracy)
|
| 245 |
+
|
| 246 |
+
actual_batch_size = (
|
| 247 |
+
training_args.per_device_train_batch_size
|
| 248 |
+
* training_args.gradient_accumulation_steps
|
| 249 |
+
* num_gpu
|
| 250 |
+
)
|
| 251 |
+
total_steps = (
|
| 252 |
+
training_args.num_train_epochs * len(train_dataset) // actual_batch_size
|
| 253 |
+
)
|
| 254 |
+
if training_args.save_epochs is not None:
|
| 255 |
+
training_args.save_steps = round(
|
| 256 |
+
training_args.save_epochs * len(train_dataset) / actual_batch_size
|
| 257 |
+
)
|
| 258 |
+
if training_args.eval_epochs is not None:
|
| 259 |
+
training_args.eval_steps = round(
|
| 260 |
+
training_args.eval_epochs * len(train_dataset) / actual_batch_size
|
| 261 |
+
)
|
| 262 |
+
if training_args.logging_epochs is not None:
|
| 263 |
+
training_args.logging_steps = round(
|
| 264 |
+
training_args.logging_epochs * len(train_dataset) / actual_batch_size
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
if training_args.local_rank == -1 or training_args.local_rank == 0:
|
| 268 |
+
print(f"===> Using {num_gpu} GPUs.")
|
| 269 |
+
print(f"===> Total Batch Size: {actual_batch_size}")
|
| 270 |
+
print(f"===> Training Epochs: {training_args.num_train_epochs}")
|
| 271 |
+
print(f"===> Total Steps: {total_steps}")
|
| 272 |
+
print(f"===> Save Steps: {training_args.save_steps}")
|
| 273 |
+
print(f"===> Eval Steps: {training_args.eval_steps}")
|
| 274 |
+
print(f"===> Logging Steps: {training_args.logging_steps}")
|
| 275 |
+
|
| 276 |
+
## ===> Step 4: Save configs for re-check
|
| 277 |
+
if training_args.local_rank == -1 or training_args.local_rank == 0:
|
| 278 |
+
save_configs_to_json(data_config, training_args, model_config, peft_lora_config)
|
| 279 |
+
|
| 280 |
+
print(train_dataset)
|
| 281 |
+
## ===> Step 5: Start Training!
|
| 282 |
+
|
| 283 |
+
special_token_ids = model.special_token_ids
|
| 284 |
+
callbacks = []
|
| 285 |
+
if special_token_ids is not None:
|
| 286 |
+
callbacks.append(PartialEmbeddingUpdateCallback(special_token_ids))
|
| 287 |
+
|
| 288 |
+
trainer = VLMRewardTrainer(
|
| 289 |
+
model=model,
|
| 290 |
+
compute_metrics=compute_metrics,
|
| 291 |
+
data_collator=data_collator,
|
| 292 |
+
args=training_args,
|
| 293 |
+
train_dataset=train_dataset,
|
| 294 |
+
eval_dataset=(test_set_dict if training_args.conduct_eval else None),
|
| 295 |
+
peft_config=peft_config,
|
| 296 |
+
callbacks=callbacks,
|
| 297 |
+
loss_type=model_config.loss_type,
|
| 298 |
+
loss_hyperparameters=model_config.loss_hyperparameters,
|
| 299 |
+
tokenizer=processor.tokenizer,
|
| 300 |
+
tied_threshold=data_config.tied_threshold,
|
| 301 |
+
visualization_steps=training_args.visualization_steps,
|
| 302 |
+
max_viz_samples=training_args.max_viz_samples,
|
| 303 |
+
)
|
| 304 |
+
trainer.train()
|
| 305 |
+
|
| 306 |
+
if training_args.local_rank == -1 or training_args.local_rank == 0:
|
| 307 |
+
model_state_dict = model.state_dict()
|
| 308 |
+
torch.save(
|
| 309 |
+
model_state_dict, os.path.join(training_args.output_dir, "final_model.pth")
|
| 310 |
+
)
|
| 311 |
+
model.config.save_pretrained(training_args.output_dir)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
if __name__ == "__main__":
|
| 315 |
+
fire.Fire(train)
|
hpsv3/utils/parser.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import yaml
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Optional, Union, Tuple, List, Literal
|
| 5 |
+
from omegaconf import OmegaConf
|
| 6 |
+
from transformers import HfArgumentParser
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from transformers import TrainingArguments
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class DataConfig:
|
| 12 |
+
train_json_list: List[str] = field(default_factory=lambda: ["/path/to/dataset/meta_data.json"])
|
| 13 |
+
val_json_list: List[str] = field(default_factory=lambda: ["/path/to/dataset/meta_data.json"])
|
| 14 |
+
test_json_list: List[str] = field(default_factory=lambda: ["/path/to/dataset/meta_data.json"])
|
| 15 |
+
soft_label: bool = False
|
| 16 |
+
confidence_threshold: Optional[float] = None
|
| 17 |
+
max_pixels: Optional[int] = 256 * 28 * 28 # Default max pixels
|
| 18 |
+
min_pixels: Optional[int] = 256 * 28 * 28
|
| 19 |
+
with_instruction: bool = True
|
| 20 |
+
tied_threshold: Optional[float] = None
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class TrainingConfig(TrainingArguments):
|
| 24 |
+
max_grad_norm: Optional[float] = 1.0
|
| 25 |
+
dataset_num_proc: Optional[int] = None
|
| 26 |
+
center_rewards_coefficient: Optional[float] = None
|
| 27 |
+
disable_flash_attn2: bool = field(default=False)
|
| 28 |
+
disable_dropout: bool = field(default=False)
|
| 29 |
+
|
| 30 |
+
vision_lr: Optional[float] = None
|
| 31 |
+
merger_lr: Optional[float] = None
|
| 32 |
+
rm_head_lr: Optional[float] = None
|
| 33 |
+
special_token_lr: Optional[float] = None
|
| 34 |
+
|
| 35 |
+
conduct_eval: Optional[bool] = True
|
| 36 |
+
load_from_pretrained: str = None
|
| 37 |
+
load_from_pretrained_step: int = None
|
| 38 |
+
logging_epochs: Optional[float] = None
|
| 39 |
+
eval_epochs: Optional[float] = None
|
| 40 |
+
save_epochs: Optional[float] = None
|
| 41 |
+
remove_unused_columns: Optional[bool] = False
|
| 42 |
+
|
| 43 |
+
save_full_model: Optional[bool] = False
|
| 44 |
+
|
| 45 |
+
# Visualization parameters
|
| 46 |
+
visualization_steps: Optional[int] = 100
|
| 47 |
+
max_viz_samples: Optional[int] = 4
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class PEFTLoraConfig:
|
| 51 |
+
lora_enable: bool = False
|
| 52 |
+
vision_lora: bool = False
|
| 53 |
+
lora_r: int = 16
|
| 54 |
+
lora_alpha: int = 32
|
| 55 |
+
lora_dropout: float = 0.05
|
| 56 |
+
lora_target_modules: Optional[List[str]] = None
|
| 57 |
+
lora_namespan_exclude: Optional[List[str]] = None
|
| 58 |
+
lora_modules_to_save: Optional[List[str]] = None
|
| 59 |
+
lora_task_type: str = "CAUSAL_LM"
|
| 60 |
+
use_rslora: bool = False
|
| 61 |
+
num_lora_modules: int = -1
|
| 62 |
+
|
| 63 |
+
def __post_init__(self):
|
| 64 |
+
if (
|
| 65 |
+
isinstance(self.lora_target_modules, list)
|
| 66 |
+
and len(self.lora_target_modules) == 1
|
| 67 |
+
):
|
| 68 |
+
self.lora_target_modules = self.lora_target_modules[0]
|
| 69 |
+
|
| 70 |
+
if (
|
| 71 |
+
isinstance(self.lora_namespan_exclude, list)
|
| 72 |
+
and len(self.lora_namespan_exclude) == 1
|
| 73 |
+
):
|
| 74 |
+
self.lora_namespan_exclude = self.lora_namespan_exclude[0]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class ModelConfig:
|
| 79 |
+
model_name_or_path: Optional[str] = None
|
| 80 |
+
model_revision: str = "main"
|
| 81 |
+
rm_head_type: str = "default"
|
| 82 |
+
rm_head_kwargs: Optional[dict] = None
|
| 83 |
+
output_dim: int = 1
|
| 84 |
+
|
| 85 |
+
use_special_tokens: bool = False
|
| 86 |
+
|
| 87 |
+
freeze_vision_tower: bool = field(default=False)
|
| 88 |
+
freeze_llm: bool = field(default=False)
|
| 89 |
+
tune_merger: bool = field(default=False)
|
| 90 |
+
trainable_visual_layers: Optional[int] = -1
|
| 91 |
+
|
| 92 |
+
torch_dtype: Optional[Literal["auto", "bfloat16", "float16", "float32"]] = None
|
| 93 |
+
trust_remote_code: bool = False
|
| 94 |
+
attn_implementation: Optional[str] = None
|
| 95 |
+
load_in_8bit: bool = False
|
| 96 |
+
load_in_4bit: bool = False
|
| 97 |
+
bnb_4bit_quant_type: Literal["fp4", "nf4"] = "nf4"
|
| 98 |
+
use_bnb_nested_quant: bool = False
|
| 99 |
+
reward_token: Literal["last", "mean", "special"] = "last"
|
| 100 |
+
loss_type: Literal["bt", "reg", "btt", "margin", "constant_margin", "scaled"] = (
|
| 101 |
+
"regular"
|
| 102 |
+
)
|
| 103 |
+
loss_hyperparameters: dict = field(default_factory=lambda: {})
|
| 104 |
+
checkpoint_path: Optional[str] = None
|
| 105 |
+
|
| 106 |
+
def __post_init__(self):
|
| 107 |
+
if self.load_in_8bit and self.load_in_4bit:
|
| 108 |
+
raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
|
| 109 |
+
|
| 110 |
+
# if isinstance(self.lora_target_modules, list) and len(self.lora_target_modules) == 1:
|
| 111 |
+
# self.lora_target_modules = self.lora_target_modules[0]
|
| 112 |
+
|
| 113 |
+
# if isinstance(self.lora_namespan_exclude, list) and len(self.lora_namespan_exclude) == 1:
|
| 114 |
+
# self.lora_namespan_exclude = self.lora_namespan_exclude[0]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
########## Functions for get trainable modules' parameters ##########
|
| 118 |
+
|
| 119 |
+
def parse_args_with_yaml(
|
| 120 |
+
dataclass_types: Tuple[type, ...],
|
| 121 |
+
config_path: str = None,
|
| 122 |
+
allow_extra_keys: bool = True,
|
| 123 |
+
is_train: bool = True,
|
| 124 |
+
) -> Tuple[Any, ...]:
|
| 125 |
+
"""
|
| 126 |
+
Parse arguments using HfArgumentParser with OmegaConf for YAML support.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
dataclass_types: Tuple of dataclass types for HfArgumentParser
|
| 130 |
+
args: Optional arguments (if None, will read from sys.argv)
|
| 131 |
+
allow_extra_keys: Whether to allow extra keys in config
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
Tuple of parsed dataclass instances
|
| 135 |
+
"""
|
| 136 |
+
# Read arguments from command line or provided args
|
| 137 |
+
# Load YAML config and merge with command line overrides
|
| 138 |
+
args = OmegaConf.to_container(OmegaConf.load(config_path))
|
| 139 |
+
if not is_train:
|
| 140 |
+
args.pop('deepspeed', None)
|
| 141 |
+
|
| 142 |
+
# Parse with HfArgumentParser
|
| 143 |
+
parser = HfArgumentParser(dataclass_types)
|
| 144 |
+
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys), config_path
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":
|
| 148 |
+
data_config, training_args, model_config, peft_lora_config = parse_args_with_yaml(
|
| 149 |
+
(DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig)
|
| 150 |
+
)
|
hpsv3/utils/training_utils.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
import safetensors
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
| 8 |
+
from deepspeed import zero
|
| 9 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
| 10 |
+
|
| 11 |
+
if hasattr(param, "ds_id"):
|
| 12 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
|
| 13 |
+
if not ignore_status:
|
| 14 |
+
print(
|
| 15 |
+
f"Parameter {name} is not available in ZeRO-3, please check the ZeRO-3 status."
|
| 16 |
+
)
|
| 17 |
+
with zero.GatheredParameters([param]):
|
| 18 |
+
param = param.data.detach().cpu().clone()
|
| 19 |
+
else:
|
| 20 |
+
param = param.detach().cpu().clone()
|
| 21 |
+
return param
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Borrowed from peft.utils.get_peft_model_state_dict
|
| 25 |
+
def get_peft_state_maybe_zero_3(named_params, bias):
|
| 26 |
+
if bias == "none":
|
| 27 |
+
to_return = {k: t for k, t in named_params if "lora_" in k}
|
| 28 |
+
elif bias == "all":
|
| 29 |
+
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
|
| 30 |
+
elif bias == "lora_only":
|
| 31 |
+
to_return = {}
|
| 32 |
+
maybe_lora_bias = {}
|
| 33 |
+
lora_bias_names = set()
|
| 34 |
+
for k, t in named_params:
|
| 35 |
+
if "lora_" in k:
|
| 36 |
+
to_return[k] = t
|
| 37 |
+
bias_name = k.split("lora_")[0] + "bias"
|
| 38 |
+
lora_bias_names.add(bias_name)
|
| 39 |
+
elif "bias" in k:
|
| 40 |
+
maybe_lora_bias[k] = t
|
| 41 |
+
for k, t in maybe_lora_bias:
|
| 42 |
+
if bias_name in lora_bias_names:
|
| 43 |
+
to_return[bias_name] = t
|
| 44 |
+
else:
|
| 45 |
+
raise NotImplementedError
|
| 46 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
|
| 47 |
+
return to_return
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
|
| 51 |
+
to_return = {k: t for k, t in named_params if "lora_" not in k}
|
| 52 |
+
if require_grad_only:
|
| 53 |
+
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
|
| 54 |
+
to_return = {
|
| 55 |
+
k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
|
| 56 |
+
}
|
| 57 |
+
return to_return
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _insert_adapter_name_into_state_dict(
|
| 61 |
+
state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
|
| 62 |
+
) -> dict[str, torch.Tensor]:
|
| 63 |
+
"""Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name."""
|
| 64 |
+
peft_model_state_dict = {}
|
| 65 |
+
for key, val in state_dict.items():
|
| 66 |
+
if parameter_prefix in key:
|
| 67 |
+
suffix = key.split(parameter_prefix)[1]
|
| 68 |
+
if "." in suffix:
|
| 69 |
+
suffix_to_replace = ".".join(suffix.split(".")[1:])
|
| 70 |
+
key = key.replace(
|
| 71 |
+
suffix_to_replace, f"{adapter_name}.{suffix_to_replace}"
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
key = f"{key}.{adapter_name}"
|
| 75 |
+
peft_model_state_dict[key] = val
|
| 76 |
+
else:
|
| 77 |
+
peft_model_state_dict[key] = val
|
| 78 |
+
return peft_model_state_dict
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def save_video(tensor, path):
|
| 82 |
+
from torchvision.io import write_video
|
| 83 |
+
|
| 84 |
+
tensor = tensor * 255.0
|
| 85 |
+
tensor = tensor.permute(0, 2, 3, 1)
|
| 86 |
+
tensor = tensor.clamp(0, 255).byte()
|
| 87 |
+
write_video(path, tensor, 4, video_codec="h264")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def load_model_from_checkpoint(model, checkpoint_dir, checkpoint_step):
|
| 91 |
+
checkpoint_paths = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
|
| 92 |
+
checkpoint_paths.sort(key=lambda x: int(x.split("-")[-1]), reverse=True)
|
| 93 |
+
|
| 94 |
+
if checkpoint_step is None or checkpoint_step == -1:
|
| 95 |
+
# get the latest checkpoint
|
| 96 |
+
checkpoint_path = checkpoint_paths[0]
|
| 97 |
+
print(
|
| 98 |
+
f"===> Checkpoint step is not provided, using the latest checkpoint: {checkpoint_path}"
|
| 99 |
+
)
|
| 100 |
+
else:
|
| 101 |
+
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
|
| 102 |
+
if checkpoint_path not in checkpoint_paths:
|
| 103 |
+
checkpoint_path = checkpoint_paths[0]
|
| 104 |
+
print(
|
| 105 |
+
f"===> Checkpoint step {checkpoint_step} not found, using the latest checkpoint: {checkpoint_path}"
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
print(
|
| 109 |
+
f"===> Checkpoint step {checkpoint_step} found, using the specified checkpoint: {checkpoint_path}"
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
checkpoint_step = checkpoint_path.split("checkpoint-")[-1].split("/")[0]
|
| 113 |
+
|
| 114 |
+
full_ckpt = os.path.join(checkpoint_path, "model.pth")
|
| 115 |
+
lora_ckpt = os.path.join(checkpoint_path, "adapter_model.safetensors")
|
| 116 |
+
non_lora_ckpt = os.path.join(checkpoint_path, "non_lora_state_dict.pth")
|
| 117 |
+
if os.path.exists(full_ckpt):
|
| 118 |
+
model_state_dict = torch.load(full_ckpt, map_location="cpu")
|
| 119 |
+
model.load_state_dict(model_state_dict)
|
| 120 |
+
else:
|
| 121 |
+
lora_state_dict = safetensors.torch.load_file(lora_ckpt)
|
| 122 |
+
non_lora_state_dict = torch.load(non_lora_ckpt, map_location="cpu")
|
| 123 |
+
|
| 124 |
+
lora_state_dict = _insert_adapter_name_into_state_dict(
|
| 125 |
+
lora_state_dict, adapter_name="default", parameter_prefix="lora_"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
model_state_dict = model.state_dict()
|
| 129 |
+
model_state_dict.update(non_lora_state_dict)
|
| 130 |
+
model_state_dict.update(lora_state_dict)
|
| 131 |
+
model.load_state_dict(model_state_dict)
|
| 132 |
+
|
| 133 |
+
return model, checkpoint_step
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def find_target_linear_names(
|
| 137 |
+
model, num_lora_modules=-1, lora_namespan_exclude=[], verbose=False
|
| 138 |
+
):
|
| 139 |
+
"""
|
| 140 |
+
Find the target linear modules for LoRA.
|
| 141 |
+
"""
|
| 142 |
+
linear_cls = torch.nn.Linear
|
| 143 |
+
embedding_cls = torch.nn.Embedding
|
| 144 |
+
lora_module_names = []
|
| 145 |
+
|
| 146 |
+
for name, module in model.named_modules():
|
| 147 |
+
if any(ex_keyword in name for ex_keyword in lora_namespan_exclude):
|
| 148 |
+
# print(f"Excluding module: {name}")
|
| 149 |
+
continue
|
| 150 |
+
|
| 151 |
+
if isinstance(module, (linear_cls, embedding_cls)):
|
| 152 |
+
lora_module_names.append(name)
|
| 153 |
+
|
| 154 |
+
if num_lora_modules > 0:
|
| 155 |
+
lora_module_names = lora_module_names[-num_lora_modules:]
|
| 156 |
+
if verbose:
|
| 157 |
+
print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
|
| 158 |
+
return lora_module_names
|
pretrained_models/download_pretrained_models.sh
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Create pretrained_models directory
|
| 4 |
+
mkdir -p pretrained_models
|
| 5 |
+
|
| 6 |
+
# Model list (using array instead of associative array to avoid ordering issues)
|
| 7 |
+
models=(
|
| 8 |
+
"black-forest-labs/FLUX.1-dev:FLUX1-dev"
|
| 9 |
+
"stabilityai/stable-diffusion-3-medium-diffusers:SD3-medium"
|
| 10 |
+
"stabilityai/stable-diffusion-xl-base-1.0:SDXL-base"
|
| 11 |
+
"Kwai-Kolors/Kolors-diffusers:Kolors"
|
| 12 |
+
"THUDM/CogView4-6B:CogView4"
|
| 13 |
+
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS:PixArt"
|
| 14 |
+
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers:HunyuanDiT"
|
| 15 |
+
"FoundationVision/Infinity"
|
| 16 |
+
"google/flan-t5-xl"
|
| 17 |
+
"Qwen/Qwen2-VL-7B-Instruct: Qwen2-VL-7B-Instruct"
|
| 18 |
+
"Qwen/Qwen2-VL-2B-Instruct: Qwen2-VL-2B-Instruct"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# Create tmux session and set up the first window
|
| 22 |
+
model_info="${models[0]}"
|
| 23 |
+
model_path="${model_info%:*}"
|
| 24 |
+
window_name="${model_info#*:}"
|
| 25 |
+
local_dir="${model_path##*/}"
|
| 26 |
+
|
| 27 |
+
# Set the first window name directly when creating session
|
| 28 |
+
tmux new-session -d -s download_pretrained_model -n "$window_name"
|
| 29 |
+
|
| 30 |
+
# Give tmux some time to initialize
|
| 31 |
+
sleep 0.5
|
| 32 |
+
|
| 33 |
+
# Set commands for the first window
|
| 34 |
+
tmux send-keys -t download_pretrained_model:"$window_name" "conda activate hpsv3" Enter
|
| 35 |
+
tmux send-keys -t download_pretrained_model:"$window_name" "export HF_ENDPOINT=https://alpha.hf-mirror.com" Enter
|
| 36 |
+
tmux send-keys -t download_pretrained_model:"$window_name" "cd pretrained_models" Enter
|
| 37 |
+
tmux send-keys -t download_pretrained_model:"$window_name" "while true; do huggingface-cli download $model_path --local-dir $local_dir && break || sleep 60; done" Enter
|
| 38 |
+
|
| 39 |
+
# Create new windows for remaining models
|
| 40 |
+
for i in $(seq 1 $((${#models[@]} - 1))); do
|
| 41 |
+
model_info="${models[$i]}"
|
| 42 |
+
model_path="${model_info%:*}"
|
| 43 |
+
window_name="${model_info#*:}"
|
| 44 |
+
local_dir="${model_path##*/}"
|
| 45 |
+
|
| 46 |
+
# Create new window
|
| 47 |
+
tmux new-window -t download_pretrained_model -n "$window_name"
|
| 48 |
+
# Add small delay to ensure window creation is complete
|
| 49 |
+
sleep 0.2
|
| 50 |
+
tmux send-keys -t download_pretrained_model:"$window_name" "conda activate hpsv3" Enter
|
| 51 |
+
tmux send-keys -t download_pretrained_model:"$window_name" "export HF_ENDPOINT=https://alpha.hf-mirror.com" Enter
|
| 52 |
+
tmux send-keys -t download_pretrained_model:"$window_name" "cd pretrained_models" Enter
|
| 53 |
+
tmux send-keys -t download_pretrained_model:"$window_name" "while true; do huggingface-cli download $model_path --local-dir $local_dir && break || sleep 60; done" Enter
|
| 54 |
+
done
|
| 55 |
+
# Switch to the first window (using the first model's window name)
|
| 56 |
+
first_window_name="${models[0]#*:}"
|
| 57 |
+
tmux select-window -t download_pretrained_model:"$first_window_name"
|
| 58 |
+
|
| 59 |
+
echo "Created tmux session 'download_pretrained_model' and started downloading all models"
|
| 60 |
+
echo "Use 'tmux attach -t download_pretrained_model' to view download progress"
|
| 61 |
+
echo "Use Ctrl+B then press number keys to switch between different download windows"
|
| 62 |
+
echo "Use 'tmux list-windows -t download_pretrained_model' to view all windows"
|
| 63 |
+
echo "Use 'tmux kill-session -t download_pretrained_model' to end the session"
|
pyproject.toml
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "hpsv3"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "HPSv3: Towards Wide-Spectrum Human Preference Score - A VLM-based preference model for image quality assessment"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.8"
|
| 11 |
+
license = {text = "MIT"}
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "Yunhao Shui"},
|
| 14 |
+
{name = "Yuhang Ma"},
|
| 15 |
+
]
|
| 16 |
+
keywords = ["machine learning", "computer vision", "human preference", "image quality", "VLM", "multimodal"]
|
| 17 |
+
classifiers = [
|
| 18 |
+
"Development Status :: 4 - Beta",
|
| 19 |
+
"Intended Audience :: Developers",
|
| 20 |
+
"Intended Audience :: Science/Research",
|
| 21 |
+
"License :: OSI Approved :: MIT License",
|
| 22 |
+
"Operating System :: OS Independent",
|
| 23 |
+
"Programming Language :: Python :: 3",
|
| 24 |
+
"Programming Language :: Python :: 3.8",
|
| 25 |
+
"Programming Language :: Python :: 3.9",
|
| 26 |
+
"Programming Language :: Python :: 3.10",
|
| 27 |
+
"Programming Language :: Python :: 3.11",
|
| 28 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 29 |
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
| 30 |
+
]
|
| 31 |
+
dependencies = [
|
| 32 |
+
"torch>=2.0.0",
|
| 33 |
+
"torchvision>=0.15.0",
|
| 34 |
+
"transformers==4.45.2",
|
| 35 |
+
"accelerate>=0.20.0",
|
| 36 |
+
"datasets>=2.10.0",
|
| 37 |
+
"diffusers>=0.20.0",
|
| 38 |
+
"Pillow>=9.0.0",
|
| 39 |
+
"numpy>=1.20.0",
|
| 40 |
+
"tqdm>=4.60.0",
|
| 41 |
+
"pyyaml>=6.0",
|
| 42 |
+
"omegaconf>=2.3.0",
|
| 43 |
+
"opencv-python>=4.5.0",
|
| 44 |
+
"safetensors>=0.3.0",
|
| 45 |
+
"einops>=0.6.0",
|
| 46 |
+
"qwen-vl-utils>=0.0.8",
|
| 47 |
+
"timm>=0.9.0",
|
| 48 |
+
"deepspeed>=0.12.0",
|
| 49 |
+
"peft>=0.8.0",
|
| 50 |
+
"trl>=0.7.0",
|
| 51 |
+
"fire>=0.7.0"
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
[project.urls]
|
| 55 |
+
Homepage = "https://mizzenai.github.io/HPSv3/"
|
| 56 |
+
Source = "https://github.com/MizzenAI/HPSv3"
|
| 57 |
+
Documentation = "https://github.com/MizzenAI/HPSv3/blob/main/README.md"
|
| 58 |
+
Paper = "https://arxiv.org/abs/2411.07232"
|
| 59 |
+
|
| 60 |
+
[tool.setuptools.packages.find]
|
| 61 |
+
include = ["hpsv3*", "generate*", "evaluate*"]
|
| 62 |
+
|
| 63 |
+
[tool.setuptools.package-data]
|
| 64 |
+
hpsv3 = ["config/*.yaml", "config/ds_config/*.json"]
|
requirements.txt
CHANGED
|
@@ -190,6 +190,6 @@ widgetsnbextension==4.0.14
|
|
| 190 |
xxhash==3.5.0
|
| 191 |
yarl==1.20.1
|
| 192 |
zipp==3.22.0
|
|
|
|
| 193 |
hpsv3
|
| 194 |
-
hpsv2
|
| 195 |
-
# flash-attn==2.7.4.post1
|
|
|
|
| 190 |
xxhash==3.5.0
|
| 191 |
yarl==1.20.1
|
| 192 |
zipp==3.22.0
|
| 193 |
+
# flash-attn==2.7.4.post1
|
| 194 |
hpsv3
|
| 195 |
+
hpsv2
|
|
|