sdsdgwe committed on
Commit 9b57ce7 · 1 Parent(s): ef8733b
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 HPSv3 Team
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
environment.yaml ADDED
@@ -0,0 +1,223 @@
+ name: hpsv3
+ channels:
+   - nvidia
+   - conda-forge
+   - defaults
+ dependencies:
+   - _libgcc_mutex=0.1=conda_forge
+   - _openmp_mutex=4.5=2_gnu
+   - bzip2=1.0.8=h4bc722e_7
+   - ca-certificates=2025.4.26=hbd8a1cb_0
+   - ld_impl_linux-64=2.43=h712a8e2_4
+   - libexpat=2.7.0=h5888daf_0
+   - libffi=3.4.6=h2dba641_1
+   - libgcc=15.1.0=h767d61c_2
+   - libgcc-ng=15.1.0=h69a702a_2
+   - libgomp=15.1.0=h767d61c_2
+   - liblzma=5.8.1=hb9d3cd8_1
+   - libnsl=2.0.1=hd590300_0
+   - libsqlite=3.50.0=hee588c1_0
+   - libuuid=2.38.1=h0b41bf4_0
+   - libxcrypt=4.4.36=hd590300_1
+   - libzlib=1.3.1=hb9d3cd8_2
+   - ncurses=6.5=h2d0b736_3
+   - openssl=3.5.0=h7b32b05_1
+   - pip=25.1.1=pyh8b19718_0
+   - python=3.10.17=hd6af730_0_cpython
+   - readline=8.2=h8c095d6_2
+   - setuptools=80.8.0=pyhff2d567_0
+   - tk=8.6.13=noxft_hd72426e_102
+   - wheel=0.45.1=pyhd8ed1ab_1
+   - pip:
+       - absl-py==2.3.0
+       - accelerate==1.8.0
+       - aiohappyeyeballs==2.6.1
+       - aiohttp==3.12.12
+       - aiosignal==1.3.2
+       - annotated-types==0.7.0
+       - antlr4-python3-runtime==4.9.3
+       - anyio==4.9.0
+       - argon2-cffi==23.1.0
+       - argon2-cffi-bindings==21.2.0
+       - arrow==1.3.0
+       - asttokens==3.0.0
+       - async-lru==2.0.5
+       - async-timeout==5.0.1
+       - attrs==25.3.0
+       - av==14.4.0
+       - babel==2.17.0
+       - beautifulsoup4==4.13.4
+       - bleach==6.2.0
+       - botocore==1.38.35
+       - certifi==2025.4.26
+       - cffi==1.17.1
+       - charset-normalizer==3.4.2
+       - comm==0.2.2
+       - contourpy==1.3.2
+       - cycler==0.12.1
+       - datasets==3.6.0
+       - debugpy==1.8.14
+       - decorator==5.2.1
+       - deepspeed==0.15.4
+       - defusedxml==0.7.1
+       - diffusers==0.33.1
+       - dill==0.3.8
+       - docstring-parser==0.16
+       - einops==0.8.1
+       - exceptiongroup==1.3.0
+       - executing==2.2.0
+       - fastjsonschema==2.21.1
+       - filelock==3.13.1
+       - fire==0.7.0
+       - fonttools==4.58.1
+       - fqdn==1.5.1
+       - frozenlist==1.7.0
+       - fsspec==2024.6.1
+       - grpcio==1.72.1
+       - h11==0.16.0
+       - hf-xet==1.1.3
+       - hjson==3.1.0
+       - httpcore==1.0.9
+       - httpx==0.28.1
+       - huggingface-hub==0.32.4
+       - idna==3.10
+       - imageio==2.37.0
+       - importlib-metadata==8.7.0
+       - ipykernel==6.29.5
+       - ipython==8.36.0
+       - ipywidgets==8.1.7
+       - isoduration==20.11.0
+       - jedi==0.19.2
+       - jinja2==3.1.6
+       - jmespath==1.0.1
+       - json5==0.12.0
+       - jsonpointer==3.0.0
+       - jsonschema==4.24.0
+       - jsonschema-specifications==2025.4.1
+       - jupyter==1.1.1
+       - jupyter-client==8.6.3
+       - jupyter-console==6.6.3
+       - jupyter-core==5.8.1
+       - jupyter-events==0.12.0
+       - jupyter-lsp==2.2.5
+       - jupyter-server==2.16.0
+       - jupyter-server-terminals==0.5.3
+       - jupyterlab==4.4.3
+       - jupyterlab-pygments==0.3.0
+       - jupyterlab-server==2.27.3
+       - jupyterlab-widgets==3.0.15
+       - kiwisolver==1.4.8
+       - markdown==3.8
+       - markdown-it-py==3.0.0
+       - markupsafe==3.0.2
+       - matplotlib==3.10.3
+       - matplotlib-inline==0.1.7
+       - mdurl==0.1.2
+       - mistune==3.1.3
+       - mpmath==1.3.0
+       - msgpack==1.1.0
+       - multidict==6.4.4
+       - multiprocess==0.70.16
+       - nbclient==0.10.2
+       - nbconvert==7.16.6
+       - nbformat==5.10.4
+       - nest-asyncio==1.6.0
+       - networkx==3.3
+       - ninja==1.11.1.4
+       - notebook==7.4.3
+       - notebook-shim==0.2.4
+       - numpy==2.1.2
+       - nvidia-cublas-cu11==11.11.3.6
+       - nvidia-cuda-cupti-cu11==11.8.87
+       - nvidia-cuda-nvrtc-cu11==11.8.89
+       - nvidia-cuda-runtime-cu11==11.8.89
+       - nvidia-cudnn-cu11==9.1.0.70
+       - nvidia-cufft-cu11==10.9.0.58
+       - nvidia-curand-cu11==10.3.0.86
+       - nvidia-cusolver-cu11==11.4.1.48
+       - nvidia-cusparse-cu11==11.7.5.86
+       - nvidia-ml-py==12.575.51
+       - nvidia-nccl-cu11==2.21.5
+       - nvidia-nvtx-cu11==11.8.86
+       - omegaconf==2.3.0
+       - opencv-python==4.11.0.86
+       - overrides==7.7.0
+       - packaging==25.0
+       - pandas==2.3.0
+       - pandocfilters==1.5.1
+       - parso==0.8.4
+       - peft==0.10.0
+       - pexpect==4.9.0
+       - pillow==11.0.0
+       - platformdirs==4.3.8
+       - prometheus-client==0.22.0
+       - prompt-toolkit==3.0.51
+       - propcache==0.3.2
+       - protobuf==6.31.1
+       - psutil==7.0.0
+       - ptyprocess==0.7.0
+       - pure-eval==0.2.3
+       - py-cpuinfo==9.0.0
+       - pyarrow==20.0.0
+       - pycparser==2.22
+       - pydantic==2.11.5
+       - pydantic-core==2.33.2
+       - pygments==2.19.1
+       - pyparsing==3.2.3
+       - python-dateutil==2.9.0.post0
+       - python-json-logger==3.3.0
+       - pytz==2025.2
+       - pyyaml==6.0.2
+       - pyzmq==26.4.0
+       - prettytable==3.8.0
+       - qwen-vl-utils==0.0.11
+       - referencing==0.36.2
+       - regex==2024.11.6
+       - requests==2.32.3
+       - rfc3339-validator==0.1.4
+       - rfc3986-validator==0.1.1
+       - rich==14.0.0
+       - rpds-py==0.25.1
+       - safetensors==0.5.3
+       - send2trash==1.8.3
+       - sentencepiece==0.2.0
+       - shtab==1.7.2
+       - six==1.17.0
+       - sniffio==1.3.1
+       - soupsieve==2.7
+       - stack-data==0.6.3
+       - sympy==1.13.1
+       - tensorboard==2.19.0
+       - tensorboard-data-server==0.7.2
+       - termcolor==3.1.0
+       - terminado==0.18.1
+       - timm==1.0.15
+       - tinycss2==1.4.0
+       - tokenizers==0.20.3
+       - tomli==2.2.1
+       - torch==2.6.0
+       - torchaudio==2.6.0
+       - torchvision==0.21.0
+       - tornado==6.5.1
+       - tqdm==4.67.1
+       - traitlets==5.14.3
+       - transformers==4.45.2
+       - triton==3.2.0
+       - trl==0.8.6
+       - typeguard==4.4.3
+       - types-python-dateutil==2.9.0.20250516
+       - typing-extensions==4.14.0
+       - typing-inspection==0.4.1
+       - tyro==0.9.24
+       - tzdata==2025.2
+       - uri-template==1.3.0
+       - urllib3==2.4.0
+       - wcwidth==0.2.13
+       - webcolors==24.11.1
+       - webencodings==0.5.1
+       - websocket-client==1.8.0
+       - werkzeug==3.1.3
+       - widgetsnbextension==4.0.14
+       - xxhash==3.5.0
+       - yarl==1.20.1
+       - zipp==3.22.0
evaluate/README.md ADDED
@@ -0,0 +1,107 @@
+ ## Model Performance Evaluation (`evaluate.py`)
+
+ This script evaluates the model's performance on a test set. It can operate in two modes:
+
+ - **`pair`**: Calculates pairwise accuracy, i.e. the fraction of pairs in which the preferred image receives a strictly higher reward (ties are counted separately).
+ - **`ranking`**: Calculates ranking accuracy over all image pairs within each sample.
+
+ **Pair-wise Sample**
+
+ For simplicity, the convention is that the image at `path1` is always the better one of the pair (`path2` holds the worse image).
+
+ ```json
+ [
+     {
+         "prompt": ".....",
+         "path1": ".....",
+         "path2": "....."
+     },
+     {
+         "prompt": ".....",
+         "path1": ".....",
+         "path2": "....."
+     },
+     ...
+ ]
+ ```
+
+ **Rank-wise Sample**
+
+ ```json
+ [
+     {
+         "id": "005658-0040",
+         "prompt": ".....",
+         "generations": [
+             "path to image1",
+             "path to image2",
+             "path to image3",
+             "path to image4"
+         ],
+         "ranking": [
+             1,
+             2,
+             5,
+             3
+         ]
+     },
+     ...
+ ]
+ ```
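+
+ Ranking accuracy is computed over all image pairs within a sample: a smaller ranking number means a better image, and a pair counts as correct when the better image receives a strictly higher reward (reward ties count as incorrect; pairs with equal ground-truth ranks are skipped). A minimal sketch of the metric, mirroring `calc_rank_acc` in `evaluate/evaluate.py` (the function name is illustrative):
+
+ ```python
+ def rank_accuracy(ranking, rewards):
+     """Fraction of image pairs whose predicted reward order matches the ranking."""
+     correct, total = 0, 0
+     for i in range(len(ranking)):
+         for j in range(i + 1, len(ranking)):
+             if ranking[i] == ranking[j]:
+                 continue  # no ground-truth preference for this pair
+             total += 1
+             if ranking[i] < ranking[j]:  # image i is ranked better
+                 correct += rewards[i] > rewards[j]
+             else:                        # image j is ranked better
+                 correct += rewards[j] > rewards[i]
+     return correct / total
+
+ # Using the sample above: rank_accuracy([1, 2, 5, 3], [0.9, 0.7, 0.1, 0.4]) == 1.0
+ ```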
+
+ ### Usage
+
+ ```bash
+ python evaluate/evaluate.py \
+     --test_json /path/to/your/test_data.json \
+     --config_path config/HPSv3_7B.yaml \
+     --checkpoint_path checkpoints/HPSv3_7B/model.pth \
+     --mode pair \
+     --batch_size 8 \
+     --num_processes 8
+ ```
+
+ **Arguments:**
+
+ - `--test_json`: (Required) Path to the JSON file containing evaluation data.
+ - `--config_path`: (Required) Path to the model's configuration file.
+ - `--checkpoint_path`: (Required) Path to the model checkpoint.
+ - `--mode`: The evaluation mode. Can be `pair` or `ranking`. (Default: `pair`)
+ - `--batch_size`: Batch size for inference. (Default: 8)
+ - `--num_processes`: Number of parallel processes to use. (Default: 8)
+
+ ---
+
+ ## Reward Benchmarking (`benchmark.py`)
+
+ This script runs inference with a reward model over one or more folders of images. It computes a reward score for each image from its corresponding text prompt, which is expected in a `.txt` file with the same base name. The script then prints statistics (count, mean, std, min, max) for each folder and saves the detailed results to a JSON file.
+
+ It supports multiple reward models through the `--mode` argument.
+
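+ For example, a benchmark folder is expected to look like this (file names are illustrative):
+
+ ```
+ folder1/
+ ├── 000001.png
+ ├── 000001.txt   # prompt for 000001.png
+ ├── 000002.jpg
+ └── 000002.txt
+ ```
+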
+ ### Usage
+
+ The script is run using `argparse`. Below is a command-line example:
+
+ ```bash
+ python evaluate/benchmark.py \
+     --config_path config/HPSv3_7B.yaml \
+     --checkpoint_path checkpoints/HPSv3_7B/model.pth \
+     --mode hpsv3 \
+     --image_folders /path/to/images/folder1 /path/to/images/folder2 \
+     --output_path ./benchmark_results.json \
+     --batch_size 16 \
+     --num_processes 8
+ ```
+
+ **Arguments:**
+
+ - `--config_path`: (Required) Path to the model's configuration file.
+ - `--checkpoint_path`: (Required) Path to the model checkpoint.
+ - `--mode`: The reward model to use. Choices: `hpsv3`, `hpsv2`, `imagereward`, `pickscore`, `aesthetic`, `clip`. (Default: `hpsv3`)
+ - `--image_folders`: (Required) One or more paths to folders containing the images to benchmark.
+ - `--output_path`: (Required) Path to save the output JSON file with results.
+ - `--batch_size`: Batch size for processing. (Default: 16)
+ - `--num_processes`: Number of parallel processes to use. (Default: 8)
+ - `--num_machines`: For distributed inference, the total number of machines. (Default: 1)
+ - `--machine_id`: For distributed inference, the ID of the current machine. (Default: 0)
+
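+ For multi-machine runs, launch the same command on every machine with a distinct `--machine_id`. Each machine writes its own `<output>_machine_<id>.json` file, and machine 0 merges whatever machine files already exist into the final output, so run it after the other machines have finished (a sketch; `...` stands for the arguments shown above):
+
+ ```bash
+ # machine 1 scores its shard and writes benchmark_results_machine_1.json
+ python evaluate/benchmark.py ... --num_machines 2 --machine_id 1
+ # machine 0 scores its shard, then merges all *_machine_*.json files it finds
+ python evaluate/benchmark.py ... --num_machines 2 --machine_id 0
+ ```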
evaluate/benchmark.py ADDED
@@ -0,0 +1,463 @@
+ import os
+ import json
+ import torch
+ import multiprocessing as mp
+ from tqdm import tqdm
+ from hpsv3.inference import HPSv3RewardInferencer
+ import argparse
+ from collections import defaultdict
+ import glob
+ import numpy as np
+ from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
+ from PIL import Image
+ import ImageReward as RM
+ from transformers import AutoProcessor, AutoModel
+
+
+ def initialize_model_hpsv2(device, cp):
+     model_dict = {}
+     model, preprocess_train, preprocess_val = create_model_and_transforms(
+         'ViT-H-14',
+         'laion2B-s32B-b79K',
+         precision='amp',
+         device=device,
+         jit=False,
+         force_quick_gelu=False,
+         force_custom_text=False,
+         force_patch_dropout=False,
+         force_image_size=None,
+         pretrained_image=False,
+         image_mean=None,
+         image_std=None,
+         light_augmentation=True,
+         aug_cfg={},
+         output_dict=True,
+         with_score_predictor=False,
+         with_region_predictor=False
+     )
+
+     checkpoint = torch.load(cp, map_location=device, weights_only=False)
+     model.load_state_dict(checkpoint['state_dict'])
+     model = model.to(device)
+     model.eval()
+     tokenizer = get_tokenizer('ViT-H-14')
+
+     model_dict['model'] = model
+     model_dict['preprocess_val'] = preprocess_val
+     return model_dict, tokenizer
+
+ def initialize_pickscore(device, checkpoint_path):
+     processor = AutoProcessor.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
+     model = AutoModel.from_pretrained(checkpoint_path).eval().to(device)
+     return model, processor
+
+ def initialize_aesthetic_model():
+     import open_clip
+     from os.path import expanduser
+     from urllib.request import urlretrieve
+     import torch.nn as nn
+
+     def get_aesthetic_model(clip_model="vit_l_14"):
+         """Load the aesthetic model with caching"""
+         home = expanduser("~")
+         cache_folder = home + "/.cache/emb_reader"
+         path_to_model = cache_folder + "/sa_0_4_" + clip_model + "_linear.pth"
+         if not os.path.exists(path_to_model):
+             os.makedirs(cache_folder, exist_ok=True)
+             url_model = (
+                 "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_" + clip_model + "_linear.pth?raw=true"
+             )
+             urlretrieve(url_model, path_to_model)
+         # Create the appropriate linear head for the chosen CLIP backbone
+         if clip_model == "vit_l_14":
+             m = nn.Linear(768, 1)
+         elif clip_model == "vit_b_32":
+             m = nn.Linear(512, 1)
+         else:
+             raise ValueError(f"Unsupported clip_model: {clip_model}")
+         m.load_state_dict(torch.load(path_to_model))
+         m.eval()
+         return m
+
+     model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
+     amodel = get_aesthetic_model(clip_model="vit_l_14")
+     return model, preprocess, amodel
+
+ def initialize_clip(device):
+     """Initialize the CLIP model and processor."""
+     model = AutoModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+     processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+     return model.to(device), processor
+
+ def score_hpsv2_batch(model_dict, tokenizer, device, img_paths: list, prompts: list) -> list:
+     model = model_dict['model']
+     preprocess_val = model_dict['preprocess_val']
+
+     # Preprocess the images as a batch, keeping only the RGB channels
+     images = [preprocess_val(Image.open(p)).unsqueeze(0)[:, :3, :, :] for p in img_paths]
+     images = torch.cat(images, dim=0).to(device=device)
+     texts = tokenizer(prompts).to(device=device)
+     with torch.no_grad():
+         outputs = model(images, texts)
+         image_features, text_features = outputs["image_features"], outputs["text_features"]
+         logits_per_image = image_features @ text_features.T
+         hps_scores = torch.diagonal(logits_per_image).cpu()
+     return hps_scores
+
+ def score_pick_score_batch(prompts, images, model, processor, device):
+     # preprocess
+     pil_images = [Image.open(p) for p in images]
+     image_inputs = processor(
+         images=pil_images,
+         padding=True,
+         truncation=True,
+         max_length=77,
+         return_tensors="pt",
+     ).to(device)
+
+     text_inputs = processor(
+         text=prompts,
+         padding=True,
+         truncation=True,
+         max_length=77,
+         return_tensors="pt",
+     ).to(device)
+
+     with torch.no_grad():
+         # embed
+         image_embs = model.get_image_features(**image_inputs)
+         image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
+
+         text_embs = model.get_text_features(**text_inputs)
+         text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
+         # score
+         scores = model.logit_scale.exp() * (text_embs @ image_embs.T)
+         scores = torch.diagonal(scores).cpu()
+
+     return scores
+
+
+ def score_aesthetic_batch(model, preprocess, aesthetic_model, device, img_paths: list) -> list:
+     """Scores a batch of images using the aesthetic model."""
+     images = [preprocess(Image.open(p)).unsqueeze(0) for p in img_paths]
+     images = torch.cat(images, dim=0).to(device=device)
+     with torch.no_grad():
+         feat = model.encode_image(images)
+         feat = feat / feat.norm(dim=-1, keepdim=True)
+         pred = aesthetic_model(feat).cpu()
+     return pred
+
+ def score_clip_batch(model, processor, device, img_paths: list, prompts: list) -> list:
+     """Scores a batch of images against prompts using CLIP."""
+     # preprocess
+     pil_images = [Image.open(p) for p in img_paths]
+     image_inputs = processor(
+         images=pil_images,
+         padding=True,
+         truncation=True,
+         max_length=77,
+         return_tensors="pt",
+     ).to(device)
+
+     text_inputs = processor(
+         text=prompts,
+         padding=True,
+         truncation=True,
+         max_length=77,
+         return_tensors="pt",
+     ).to(device)
+
+     with torch.no_grad():
+         # embed
+         image_embs = model.get_image_features(**image_inputs)
+         image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
+
+         text_embs = model.get_text_features(**text_inputs)
+         text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
+         # score: cosine similarity between matched image/text pairs
+         scores = image_embs @ text_embs.T
+         scores = torch.diagonal(scores).cpu()
+
+     return scores
+
+ def calculate_category_stats(data_dict):
+     """Calculate statistics for each category"""
+     stats = {}
+     for category, data_list in data_dict.items():
+         if not data_list:
+             stats[category] = {
+                 'count': 0,
+                 'mean': 0.0,
+                 'std': 0.0,
+                 'min': 0.0,
+                 'max': 0.0
+             }
+             continue
+
+         rewards = [item['reward'] for item in data_list]
+         stats[category] = {
+             'count': len(rewards),
+             'mean': float(np.mean(rewards)),
+             'std': float(np.std(rewards)),
+             'min': float(np.min(rewards)),
+             'max': float(np.max(rewards))
+         }
+     total_mean = np.mean([stat['mean'] for stat in stats.values() if stat['count'] > 0])
+     stats['OVERALL'] = {
+         'count': sum(stat['count'] for stat in stats.values()),
+         'mean': float(total_mean),
+         'std': float(np.std([stat['mean'] for stat in stats.values() if stat['count'] > 0])),
+         'min': float(min(stat['min'] for stat in stats.values() if stat['count'] > 0)),
+         'max': float(max(stat['max'] for stat in stats.values() if stat['count'] > 0))
+     }
+     return stats
+
+ def print_stats(stats):
+     print(f"{'Category':<30} {'Count':<8} {'Mean':<10} {'Std':<10} {'Min':<10} {'Max':<10}")
+     print("-" * 78)
+     for category, stat in stats.items():
+         category_name = category  # the folder path serves as the category name
+         print(f"{category_name:<30} {stat['count']:<8} {stat['mean']:<10.4f} {stat['std']:<10.4f} {stat['min']:<10.4f} {stat['max']:<10.4f}")
+
+     # Calculate overall statistics
+     if stats:
+         all_counts = [stat['count'] for stat in stats.values()]
+         all_means = [stat['mean'] for stat in stats.values() if stat['count'] > 0]
+         if all_means:
+             print("-" * 78)
+             print(f"{'OVERALL':<30} {sum(all_counts):<8} {np.mean(all_means):<10.4f} {'':<10} {min([stat['min'] for stat in stats.values() if stat['count'] > 0]):<10.4f} {max([stat['max'] for stat in stats.values() if stat['count'] > 0]):<10.4f}")
+
+ def worker_process(process_id, process_dict, config_path, checkpoint_path, mode, device_id, dtype, batch_size, return_dict):
+     """Worker process function that processes a chunk of data"""
+     category_rewards = defaultdict(list)
+
+     device = f"cuda:{device_id}" if torch.cuda.is_available() else "cpu"
+     if mode == 'imagereward':
+         model = RM.load("ImageReward-v1.0")
+     elif mode == 'hpsv2':
+         inferencer = initialize_model_hpsv2(device, checkpoint_path)
+         model_dict, tokenizer = inferencer
+     elif mode == 'hpsv3':
+         inferencer = HPSv3RewardInferencer(config_path=config_path, checkpoint_path=checkpoint_path, device=device)
+     elif mode == 'pickscore':
+         model, processor = initialize_pickscore(device, checkpoint_path)
+     elif mode == 'aesthetic':
+         model, preprocess, aesthetic_model = initialize_aesthetic_model()
+         model = model.to(device)
+         aesthetic_model = aesthetic_model.to(device)
+     elif mode == 'clip':
+         model, processor = initialize_clip(device)
+         model = model.to(device)
+     else:
+         raise ValueError(f"Unsupported mode: {mode}")
+
+     for category, chunk_data in tqdm(process_dict.items(), total=len(process_dict), desc='Total', disable=process_id != 0):
+         processed_data = []
+         # Process data in batches
+         for batch_start in tqdm(range(0, len(chunk_data), batch_size),
+                                 total=(len(chunk_data) + batch_size - 1) // batch_size,
+                                 desc=f"Category {category}", disable=process_id != 0):
+             batch_end = min(batch_start + batch_size, len(chunk_data))
+             image_paths = chunk_data[batch_start:batch_end]
+             text_paths = [p[:-4] + '.txt' for p in image_paths]
+
+             # The full contents of each .txt file serve as the prompt for its image
+             prompts = [open(p, 'r').read() for p in text_paths]
+
+             with torch.no_grad():
+                 if mode == 'imagereward':
+                     rewards = torch.tensor([model.score(prompt, image_path) for prompt, image_path in zip(prompts, image_paths)])
+                 elif mode == 'hpsv2':
+                     rewards = score_hpsv2_batch(model_dict, tokenizer, device, image_paths, prompts)
+                 elif mode == 'hpsv3':
+                     rewards = inferencer.reward(image_paths, prompts)
+                 elif mode == 'pickscore':
+                     rewards = score_pick_score_batch(prompts, image_paths, model, processor, device)
+                 elif mode == 'aesthetic':
+                     rewards = score_aesthetic_batch(model, preprocess, aesthetic_model, device, image_paths)
+                 elif mode == 'clip':
+                     rewards = score_clip_batch(model, processor, device, image_paths, prompts)
+                 else:
+                     raise ValueError(f"Unsupported mode: {mode}")
+
+             torch.cuda.empty_cache()
+             for i in range(len(image_paths)):
+                 if rewards.ndim == 2:
+                     reward = rewards[i][0].item()
+                 else:
+                     reward = rewards[i].item()
+                 processed_data.append({
+                     'image_path': image_paths[i],
+                     'reward': reward,
+                     'prompt': prompts[i]
+                 })
+
+         category_rewards[category] = processed_data
+
+     return_dict[process_id] = {
+         'data': category_rewards,
+     }
+
+ def chunk_list(data_list, num_chunks):
+     """Split list into roughly equal chunks"""
+     chunk_size = len(data_list) // num_chunks
+     remainder = len(data_list) % num_chunks
+
+     chunks = []
+     start = 0
+     for i in range(num_chunks):
+         # Add one extra item to the first 'remainder' chunks
+         current_chunk_size = chunk_size + (1 if i < remainder else 0)
+         end = start + current_chunk_size
+         chunks.append(data_list[start:end])
+         start = end
+
+     return chunks
+
+ def main(config_path, checkpoint_path, mode, image_folders, output_path, batch_size=16, num_processes=8, num_machines=1, machine_id=0):
+     print(f"Config path: {config_path}")
+
+     dtype = torch.bfloat16
+
+     # Gather all data first
+     folder_dict = {}
+     for folder in image_folders:
+         images = []
+         for ext in ['.png', '.jpg']:
+             images.extend(glob.glob(os.path.join(folder, "**", f"*{ext}"), recursive=True))
+         machine_image_chunks = chunk_list(images, num_machines)
+         image_list = machine_image_chunks[machine_id] if machine_id < len(machine_image_chunks) else []
+         print(f"Folder {folder} total data points: {len(image_list)}")
+         data_chunks = chunk_list(image_list, num_processes)
+         print(f"Folder {folder} data split into {num_processes} chunks with sizes: {[len(chunk) for chunk in data_chunks]}")
+         folder_dict[folder] = data_chunks
+
+     per_process_folder_dict = []
+     for i in range(num_processes):
+         one_dict = {}
+         for key, value in folder_dict.items():
+             one_dict[key] = value[i] if i < len(value) else []
+         per_process_folder_dict.append(one_dict)
+
+     # Create manager for shared data between processes
+     with mp.Manager() as manager:
+         return_dict = manager.dict()
+         processes = []
+
+         # Start processes
+         for i in range(num_processes):
+             device_id = i % torch.cuda.device_count() if torch.cuda.is_available() else 0
+
+             p = mp.Process(target=worker_process,
+                            args=(i, per_process_folder_dict[i], config_path, checkpoint_path, mode, device_id, dtype, batch_size, return_dict))
+             p.start()
+             processes.append(p)
+
+         for p in processes:
+             p.join()
+
+         # Collect results from all processes
+         all_processed_data = {}
+         for i in range(num_processes):
+             if i in return_dict:
+                 result = return_dict[i]
+                 process_data = result['data']
+                 # Merge data from each process
+                 for category, data_list in process_data.items():
+                     if category not in all_processed_data:
+                         all_processed_data[category] = []
+                     all_processed_data[category].extend(data_list)
+             else:
+                 print(f"No result from process {i}")
+
+     # Calculate and print statistics for current machine
+     if all_processed_data:
+         stats = calculate_category_stats(all_processed_data)
+         print(f"\n=== Machine {machine_id} Statistics ===")
+         print_stats(stats)
+
+     # Save results
+     if num_machines > 1:
+         # Save current machine's results
+         machine_output_path = output_path.replace('.json', f'_machine_{machine_id}.json')
+         with open(machine_output_path, "w") as f:
+             json.dump(all_processed_data, f, indent=4)
+         print(f"Machine {machine_id} results saved to {machine_output_path}")
+
+         # If this is machine 0, try to gather results from all machines
+         if machine_id == 0:
+             print("Waiting for all machines to complete...")
+             # Note: In practice, you might want to implement a proper synchronization mechanism
+             # For now, this assumes all machine files exist
+             final_result = {}
+             for i in range(num_machines):
+                 machine_file = output_path.replace('.json', f'_machine_{i}.json')
+                 if os.path.exists(machine_file):
+                     print(f"Loading results from machine {i}")
+                     with open(machine_file, 'r') as f:
+                         machine_data = json.load(f)
+                     # Merge machine data
+                     for category, data_list in machine_data.items():
+                         if category not in final_result:
+                             final_result[category] = []
+                         final_result[category].extend(data_list)
+                 else:
+                     print(f"Warning: Machine {i} results file not found: {machine_file}")
+
+             # Calculate and print statistics for final results
+             stats = calculate_category_stats(final_result)
+             print("\n=== Final Combined Statistics ===")
+             print_stats(stats)
+
+             # Save final combined results with statistics
+             final_output = {
+                 'statistics': stats,
+                 'data': final_result,
+             }
+             with open(output_path, "w") as f:
+                 json.dump(final_output, f, indent=4)
+             print(f"Final combined results saved to {output_path}")
+     else:
+         # Single machine case - calculate statistics
+         stats = calculate_category_stats(all_processed_data)
+         print("\n=== Statistics ===")
+         print_stats(stats)
+
+         # Save results with statistics
+         output_data = {
+             'statistics': stats,
+             'data': all_processed_data,
+         }
+         with open(output_path, "w") as f:
+             json.dump(output_data, f, indent=4)
+         print(f"Results saved to {output_path}")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Process images with HPSv3 reward inference')
+     parser.add_argument('--config_path', type=str, help='Path to the configuration file')
+     parser.add_argument('--checkpoint_path', type=str, help='Path to the model checkpoint file')
+     parser.add_argument('--mode', type=str, choices=['imagereward', 'hpsv2', 'hpsv3', 'pickscore', 'aesthetic', 'clip'], default='hpsv3')
+     parser.add_argument('--image_folders', type=str, nargs='+', required=True, help='List of image folder paths to process')
+     parser.add_argument('--output_path', type=str, required=True, help='Path to save the output JSON file')
+     parser.add_argument('--batch_size', type=int, default=16, help='Batch size for processing (default: 16)')
+     parser.add_argument('--num_processes', type=int, default=8, help='Number of processes to use (default: 8)')
+     parser.add_argument('--num_machines', type=int, default=1, help='Total number of machines (default: 1)')
+     parser.add_argument('--machine_id', type=int, default=0, help='ID of current machine (default: 0)')
+
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     mp.set_start_method('spawn', force=True)
+
+     args = parse_args()
+     main(
+         config_path=args.config_path,
+         checkpoint_path=args.checkpoint_path,
+         mode=args.mode,
+         image_folders=args.image_folders,
+         output_path=args.output_path,
+         batch_size=args.batch_size,
+         num_processes=args.num_processes,
+         num_machines=args.num_machines,
+         machine_id=args.machine_id
+     )
evaluate/evaluate.py ADDED
@@ -0,0 +1,203 @@
+ import os
+ import json
+ import torch
+ import multiprocessing as mp
+ from tqdm import tqdm
+ from hpsv3.inference import HPSv3RewardInferencer
+ from multiprocessing import Process, Queue
+ import math
+ import fire
+ import prettytable
+
+ def calc_rank_acc(score_sample, predict_sample):
+     """Pairwise rank accuracy: a smaller ranking number means a better image,
+     and a pair is correct only when the better image gets a strictly higher reward."""
+     tol_cnt = 0.
+     true_cnt = 0.
+     for idx in range(len(score_sample)):
+         item_base = score_sample[idx]["ranking"]
+         item = predict_sample[idx]["rewards"]
+         for i in range(len(item_base)):
+             for j in range(i + 1, len(item_base)):
+                 if item_base[i] > item_base[j]:
+                     if item[i] >= item[j]:
+                         tol_cnt += 1
+                     elif item[i] < item[j]:
+                         tol_cnt += 1
+                         true_cnt += 1
+                 elif item_base[i] < item_base[j]:
+                     if item[i] > item[j]:
+                         tol_cnt += 1
+                         true_cnt += 1
+                     elif item[i] <= item[j]:
+                         tol_cnt += 1
+     return true_cnt / tol_cnt
+
+
+ def worker_process(process_id, data_chunk, config_path, checkpoint_path, batch_size, result_queue, mode):
+     """
+     Worker function for each process to handle a chunk of data
+     """
+     # Each process uses a different GPU (cycle through available GPUs)
+     num_gpus = torch.cuda.device_count()
+     device = f"cuda:{process_id % num_gpus}" if num_gpus > 0 else "cpu"
+     dtype = torch.bfloat16
+
+     print(f"Process {process_id} starting with device {device}, processing {len(data_chunk)} items")
+
+     # Initialize model for this process
+     inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, device=device, dtype=dtype)
+
+     process_correct = 0
+     process_equal = 0
+     process_results = []
+
+     for batch_start in tqdm(range(0, len(data_chunk), batch_size),
+                             total=(len(data_chunk) + batch_size - 1) // batch_size,
+                             desc=f"Process {process_id}"):
+         batch_end = min(batch_start + batch_size, len(data_chunk))
+         batch_info = data_chunk[batch_start:batch_end]
+         if mode == 'pair':
+             image_paths_1 = [info["path1"] for info in batch_info]
+             image_paths_2 = [info["path2"] for info in batch_info]
+             prompts = [info["prompt"] for info in batch_info]
+
+             with torch.no_grad():
+                 rewards_1 = inferencer.reward(image_paths_1, prompts)
+                 rewards_2 = inferencer.reward(image_paths_2, prompts)
+
+             for i in range(len(batch_info)):
+                 info = batch_info[i]
+                 if rewards_1.ndim == 2:
+                     reward_1, reward_2 = rewards_1[i][0].item(), rewards_2[i][0].item()
+                 else:
+                     reward_1, reward_2 = rewards_1[i].item(), rewards_2[i].item()
+
+                 item_result = {
+                     'reward_1': reward_1,
+                     'reward_2': reward_2,
+                     'correct': reward_1 > reward_2,
+                     'equal': reward_1 == reward_2,
+                     'info': info
+                 }
+                 process_results.append(item_result)
+
+                 print(f"Process {process_id} - Reward 1: {reward_1}, Reward 2: {reward_2}")
+                 if reward_1 > reward_2:
+                     process_correct += 1
+                 if reward_1 == reward_2:
+                     process_equal += 1
+
+         elif mode == 'ranking':
+             for item in batch_info:
+                 rewards = inferencer.reward(item["generations"], item["prompt"])
+                 predict_item = {
+                     "id": item["id"],
+                     "prompt": item["prompt"],
+                     "rewards": rewards
+                 }
+                 process_results.append(predict_item)
+
+     # Put results in queue
+     if mode == 'pair':
+         result_queue.put({
+             'process_id': process_id,
+             'correct': process_correct,
+             'equal': process_equal,
+             'total': len(data_chunk),
+             'results': process_results
+         })
+     elif mode == 'ranking':
+         result_queue.put({
+             'process_id': process_id,
+             'results': process_results
+         })
+
+     print(f"Process {process_id} completed: {process_correct}/{len(data_chunk)} correct, {process_equal}/{len(data_chunk)} equal")
+
+ def main(test_json, config_path=None, batch_size=8, num_processes=8, checkpoint_path=None, mode='pair'):
+     assert mode in ['pair', 'ranking'], "Mode must be either 'pair' or 'ranking'"
+     assert checkpoint_path is not None, "Checkpoint path must be provided for inference"
+
+     mp.set_start_method('spawn', force=True)
+
+     info_list = json.load(open(test_json, "r"))
+
+     print(f"Total items to process: {len(info_list)}")
+     # Split data into chunks for each process
+     chunk_size = math.ceil(len(info_list) / num_processes)
+     data_chunks = []
+     for i in range(num_processes):
+         start_idx = i * chunk_size
+         end_idx = min((i + 1) * chunk_size, len(info_list))
+         if start_idx < len(info_list):
+             chunk = info_list[start_idx:end_idx]
+             data_chunks.append(chunk)
+             print(f"Process {i}: {len(chunk)} items (indices {start_idx}-{end_idx - 1})")
+
+     # Ensure we have the right number of non-empty chunks
+     actual_processes = len(data_chunks)
+     print(f"Using {actual_processes} processes")
+
+     # Create result queue and processes
+     result_queue = Queue()
+     processes = []
+
+     print("Starting processes...")
+     for i in range(actual_processes):
+         p = Process(target=worker_process, args=(i, data_chunks[i], config_path, checkpoint_path, batch_size, result_queue, mode))
+         p.start()
+         processes.append(p)
+
+     # Wait for all processes to complete and collect results
+     all_results = []
+     total_correct = 0
+     total_equal = 0
+     total_items = 0
+
+     print("Waiting for processes to complete...")
+     for i in range(actual_processes):
+         result = result_queue.get()
+         all_results.append(result)
+         if mode == 'pair':
+             total_correct += result['correct']
+             total_equal += result['equal']
+             total_items += result['total']
+
+             print(f"Process {result['process_id']} finished: {result['correct']}/{result['total']} correct, {result['equal']}/{result['total']} equal")
+
+     # Wait for all processes to join
+     for p in processes:
+         p.join()
+
+     if mode == 'pair':
+         aggregated_results = {
+             'total_correct': total_correct,
+             'total_equal': total_equal,
+             'total_items': total_items,
+             'accuracy': total_correct / total_items,
+             'process_results': all_results
+         }
+         table = prettytable.PrettyTable()
+         table.field_names = ["Total Items", "Correct", "Equal", "Incorrect", "Accuracy (%)"]
+
+         incorrect = aggregated_results['total_items'] - aggregated_results['total_correct'] - aggregated_results['total_equal']
+         accuracy_percent = 100 * aggregated_results['total_correct'] / aggregated_results['total_items']
+
+         table.add_row([
+             aggregated_results['total_items'],
+             aggregated_results['total_correct'],
+             aggregated_results['total_equal'],
+             incorrect,
+             f"{accuracy_percent:.2f}"
+         ])
+     elif mode == 'ranking':
+         # Queue results arrive in arbitrary order; restore chunk order so the
+         # concatenated predictions line up with the items in info_list
+         all_results.sort(key=lambda r: r['process_id'])
+         predict_sample = [item for r in all_results for item in r['results']]
+         rank_acc = calc_rank_acc(info_list, predict_sample)
+         table = prettytable.PrettyTable()
+         table.field_names = ["Total Items", "Rank Accuracy (%)"]
+         table.add_row([len(info_list), f"{rank_acc * 100:.2f}"])
+
+     print(table)
+
+ if __name__ == "__main__":
+     fire.Fire(main)
generate/README.md ADDED
@@ -0,0 +1,104 @@
+ # Image Generation Module
+
+ This module generates images from text prompts using various pretrained diffusion models. It supports parallel generation across multiple GPUs and can easily be extended with new models.
+
+ ## File Structure
+
+ - `gen_images_from_prompt.py`: The main script for running the image generation process. It reads prompts from a JSON file and handles command-line arguments.
+ - `generator.py`: Contains the core `Generator` class, which manages the model pipelines and distributes the generation tasks across different devices.
+ - `utils/pipelines.py`: Defines the configurations for all supported pretrained models. This is where you can add or modify model parameters.
+ - `utils/utils.py`: Contains helper functions for initializing `diffusers` pipelines and interacting with model APIs.
+
+ ## How to Use
+
+ To generate images, run the main script with the required arguments.
+
+ ### Basic Command
+
+ ```bash
+ python gen_images_from_prompt.py \
+     --json_path /path/to/your/prompts.json \
+     --out_dir /path/to/your/output_directory \
+     --pipeline_name sd_xl flux_schnell
+ ```
+
+ ### Command-Line Arguments
+
+ - `--json_path` (required): Path to a JSON file containing a list of prompts. Each item in the list should be an object with a `"caption"` key.
+
+   - To generate images that match real reference images, specify `"image_file"` (the path of the original image) and its `"aspect_ratio"`. The exact height and width are then adjusted to the model's best-practice resolution.
+   - To generate images from a prompt only, specify `"save_name"`, `"height"`, and `"width"` (see the second example below).
+
+ **Example `prompts.json` format:**
+ ```json
+ [
+     {
+         "image_file": "1.jpg",
+         "caption": "A beautiful landscape painting of a mountain range at sunset.",
+         "aspect_ratio": 0.5
+     },
+     {
+         "image_file": "2.jpg",
+         "caption": "A close-up photo of a red rose with water droplets.",
+         "aspect_ratio": 1.0
+     },
+     {
+         "image_file": "3.jpg",
+         "caption": "An astronaut riding a horse on Mars, digital art.",
+         "aspect_ratio": 1.77
+     }
+ ]
+ ```
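+
+ **Example for prompt-only generation** (illustrative values; note that `"image_file"` must still be present and set to `null`, and `"aspect_ratio"` is still required because the script sorts entries by it):
+
+ ```json
+ [
+     {
+         "image_file": null,
+         "save_name": "astronaut_horse",
+         "caption": "An astronaut riding a horse on Mars, digital art.",
+         "height": 1024,
+         "width": 1024,
+         "aspect_ratio": 1.0
+     }
+ ]
+ ```
+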
+ - `--out_dir` (required): The root directory where generated images will be saved. A subdirectory will be created for each pipeline.
+ - `--pipeline_name` (required): One or more pipeline configuration names to use for generation. The script appends `_pipe` to each name and looks up the resulting `PipelineParam` variable in `utils/pipelines.py`, so `sd_xl` selects `sd_xl_pipe`.
+ - `--num_devices`: The number of GPU devices to use for generation. Defaults to `8`.
+ - `--batch_size`: The batch size per device. Defaults to `1`.
+ - `--num_machine`: The total number of machines used in a distributed setup. Defaults to `1`.
+ - `--machine_id`: The ID of the current machine in a distributed setup. Defaults to `0`.
+ - `--enable_availabel_check`: If set, the script will first run a quick check on a small batch to ensure each pipeline can be loaded and run without errors.
+ - `--reverse`: If set, the order of the specified pipelines will be reversed.
+
+ ## How to Add a New Model
+
+ You can easily add a new text-to-image model by configuring it in the `utils/pipelines.py` file.
+
+ 1. **Open `utils/pipelines.py`**.
+ 2. **Import `PipelineParam`** if it's not already imported.
+ 3. **Create a new `PipelineParam` instance** for your model. Define the following parameters:
+    - `pipeline_name`: The model's path on the Hugging Face Hub or a local directory.
+    - `generation_path`: The name of the subdirectory where the output images will be saved.
+    - `pipeline_type`: The type of pipeline, e.g., `'t2i'` (text-to-image) or `'t2v'` (text-to-video). Defaults to `'t2i'`.
+    - `pipe_init_kwargs`: A dictionary of arguments required for initializing the model pipeline (e.g., `{"torch_dtype": torch.float16}`).
+    - `generation_kwargs`: A dictionary of arguments for the generation process (e.g., `{"guidance_scale": 7.0, "num_inference_steps": 28}`).
+    - `base_resolution`: The base resolution the model was trained on (e.g., `1024`).
+    - `force_aspect_ratio`: Optionally force a specific aspect ratio (e.g., `1` for square images); see the size sketch after the example below.
+
+ **Example:**
+
+ ```python
+ from pydantic import BaseModel, Field
+ import torch
+
+ class PipelineParam(BaseModel):
+     # ... (class definition)
+
+ # Add your new model configuration
+ my_new_model_pipe = PipelineParam(
+     pipeline_name='organization/my-cool-model',
+     generation_path='generation/my_cool_model',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "guidance_scale": 5.0,
+         "num_inference_steps": 30,
+     }
+ )
+ ```
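+
+ For reference, `generator.py` derives each sample's output size from `base_resolution` and its aspect ratio (width/height), snapping both sides down to multiples of 64 so that `height * width` stays close to `base_resolution ** 2`. A minimal sketch of that computation (the function name is illustrative):
+
+ ```python
+ def resolve_size(base_resolution: int, aspect_ratio: float) -> tuple[int, int]:
+     # aspect_ratio = width / height; keeps width * height ~ base_resolution ** 2
+     height = int(base_resolution / aspect_ratio ** 0.5 // 64 * 64)
+     width = int(base_resolution * aspect_ratio ** 0.5 // 64 * 64)
+     return height, width
+
+ # resolve_size(1024, 1.77) -> (768, 1344)
+ ```
+
+ When `force_aspect_ratio` is set, the sides are instead computed as `base_resolution / force_aspect_ratio` and `base_resolution * force_aspect_ratio`, again snapped to multiples of 64.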
+
+ 4. **Run the generation script** using the name of your new `PipelineParam` variable, without the `_pipe` suffix, in the `--pipeline_name` argument.
+
+ ```bash
+ python gen_images_from_prompt.py --pipeline_name my_new_model ...
+ ```
generate/__init__.py ADDED
File without changes
generate/gen_images_from_prompt.py ADDED
@@ -0,0 +1,117 @@
+ from generator import Generator
+ import json
+ import os
+ import torch
+ import gc
+ from utils.pipelines import *
+ import argparse
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Generate images from prompts")
+     parser.add_argument(
+         "--json_path",
+         type=str,
+         help="Path to the prompts JSON file",
+     )
+     parser.add_argument(
+         "--out_dir",
+         type=str,
+         help="Output directory",
+     )
+     parser.add_argument("--num_devices", type=int, default=8, help="Number of devices")
+     parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
+     parser.add_argument("--num_machine", type=int, default=1, help="Number of machines")
+     parser.add_argument("--machine_id", type=int, default=0, help="Machine ID")
+     parser.add_argument(
+         "--pipeline_name", type=str, nargs="+", default=None, help="Pipeline name(s), without the _pipe suffix"
+     )
+     parser.add_argument("--enable_availabel_check", action="store_true")
+     parser.add_argument("--reverse", action="store_true")
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     num_devices = args.num_devices
+     pipeline_params = [globals()[f"{name}_pipe"] for name in args.pipeline_name]
+
+     if args.reverse:
+         pipeline_params = pipeline_params[::-1]
+
+     # First check that every pipeline can be loaded and run
+     if args.enable_availabel_check:
+         print(f"Checking {len(pipeline_params)} pipelines")
+         for pipeline_param in pipeline_params:
+             generator = Generator(
+                 pipe_name=pipeline_param.pipeline_name,
+                 pipe_type=pipeline_param.pipeline_type,
+                 pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
+                 num_devices=num_devices,
+             )
+
+             with open(args.json_path, "r") as f:
+                 entries = json.load(f)
+             info_dict = entries[: args.batch_size]
+             generator.generate(
+                 info_dict,
+                 os.path.join(args.out_dir, pipeline_param.generation_path),
+                 batch_size=args.batch_size,
+                 num_processes=num_devices,
+                 seed=42,
+                 weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
+                 generation_kwargs=pipeline_param.generation_kwargs,
+                 base_resolution=pipeline_param.base_resolution,
+                 force_aspect_ratio=pipeline_param.force_aspect_ratio,
+             )
+             del generator
+             gc.collect()
+             torch.cuda.empty_cache()
+             print(f"Finished checking {pipeline_param.pipeline_name}")
+
+     for pipeline_param in pipeline_params:
+         generator = Generator(
+             pipe_name=pipeline_param.pipeline_name,
+             pipe_type=pipeline_param.pipeline_type,
+             pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
+             num_devices=num_devices,
+         )
+
+         with open(args.json_path, "r") as f:
+             entries = json.load(f)
+
+         # Keep only this machine's shard of the entries
+         for i in range(args.num_machine):
+             start_idx = i * len(entries) // args.num_machine
+             end_idx = (
+                 (i + 1) * len(entries) // args.num_machine
+                 if i != args.num_machine - 1
+                 else len(entries)
+             )
+             if i == args.machine_id:
+                 info_dict = entries[start_idx:end_idx]
+
+         info_dict = sorted(info_dict, key=lambda x: x["aspect_ratio"])
+
+         print(f"Generating {len(info_dict)} images")
+         generator.generate(
+             info_dict,
+             os.path.join(args.out_dir, pipeline_param.generation_path),
+             batch_size=args.batch_size,
+             num_processes=num_devices,
+             seed=42,
+             weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
+             generation_kwargs=pipeline_param.generation_kwargs,
+             base_resolution=pipeline_param.base_resolution,
+             force_aspect_ratio=pipeline_param.force_aspect_ratio,
+         )
+
+         print(f"Finished generating {pipeline_param.pipeline_name}")
+
+         for pipeline in generator.pipelines:
+             pipeline.to("cpu")
+         del generator
+         torch.cuda.empty_cache()
+         gc.collect()
+
+
+ if __name__ == "__main__":
+     main()
generate/generator.py ADDED
@@ -0,0 +1,211 @@
+ import torch
+ import os
+ import inspect
+ from PIL import Image
+ from tqdm import tqdm
+ from utils.utils import init_multiple_pipelines
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ Image.MAX_IMAGE_PIXELS = None
+
+
+ class Generator:
+     def __init__(
+         self, pipe_name, pipe_type, pipe_init_kwargs, num_devices, device_id=None
+     ):
+         self.pipe_names = pipe_name
+         self.pipe_type = pipe_type
+         self.pipe_init_kwargs = pipe_init_kwargs
+         self.pipelines = init_multiple_pipelines(
+             pipe_name, pipe_init_kwargs, num_devices, device_id
+         )
+
+     def generate_imgs(
+         self,
+         num_device,
+         batch_size,
+         generation_path,
+         info_dict,
+         pipeline,
+         device_id,
+         weight_dtype,
+         seed,
+         base_resolution,
+         force_aspect_ratio,
+         generation_kwargs,
+     ):
+         torch.cuda.set_device(f"cuda:{device_id % num_device}")
+         device = torch.device(f"cuda:{device_id % num_device}")
+
+         # Each device handles a contiguous slice of the prompts
+         num_prompts_per_device = len(info_dict) // num_device
+         start_idx = device_id * num_prompts_per_device
+         end_idx = (
+             start_idx + num_prompts_per_device
+             if device_id != (num_device - 1)
+             else len(info_dict)
+         )
+
+         device_info_dict = info_dict[start_idx:end_idx]
+
+         print(f"Device {device} generating for prompts {start_idx} to {end_idx - 1}")
+
+         print("## Prepare generation dataset")
+
+         total_batches = len(device_info_dict) // batch_size + (
+             1 if len(device_info_dict) % batch_size != 0 else 0
+         )
+         for batch_idx in tqdm(
+             range(total_batches), desc="Pipeline: " + self.pipe_names
+         ):
+             batch_info_dict = device_info_dict[
+                 batch_idx * batch_size : (batch_idx + 1) * batch_size
+             ]
+             save_paths = []
+             for info in batch_info_dict:
+                 if info["image_file"] is not None:
+                     save_paths.append(
+                         os.path.join(generation_path, info["image_file"][:-4] + ".png")
+                     )
+                 else:
+                     save_paths.append(
+                         os.path.join(generation_path, info["save_name"] + ".png")
+                     )
+
+             # Skip samples whose output already exists
+             exists_idx = []
+             for i, save_path in enumerate(save_paths):
+                 if os.path.exists(save_path):
+                     exists_idx.append(i)
+
+             batch_info_dict = [
+                 batch_info_dict[i]
+                 for i in range(len(batch_info_dict))
+                 if i not in exists_idx
+             ]
+             if len(batch_info_dict) == 0:
+                 continue
+
+             batch_prompts = [info["caption"] for info in batch_info_dict]
+             batch_image_file = [
+                 info["image_file"] for info in batch_info_dict
+             ]
+             if batch_image_file[0] is not None:
+                 try:
+                     batch_image_sizes = [
+                         Image.open(image_file).size for image_file in batch_image_file
+                     ]
+                 except Exception:
+                     batch_image_sizes = None
+             else:
+                 batch_image_sizes = [
+                     (batch_info_dict[i]["width"], batch_info_dict[i]["height"])
+                     for i in range(len(batch_info_dict))
+                 ]
+
+             if batch_image_sizes is None:
+                 aspect_ratios = [
+                     info["aspect_ratio"] for info in batch_info_dict
+                 ]
+             else:
+                 aspect_ratios = [size[0] / size[1] for size in batch_image_sizes]
+
+             if force_aspect_ratio:
+                 height = int(base_resolution / force_aspect_ratio // 64 * 64)
+                 width = int(base_resolution * force_aspect_ratio // 64 * 64)
+             else:
+                 # Scale base_resolution by the aspect ratio to get height and width,
+                 # keeping height * width roughly equal to base_resolution ** 2
+                 height = int(base_resolution / aspect_ratios[0] ** (0.5) // 64 * 64)
+                 width = int(base_resolution * aspect_ratios[0] ** (0.5) // 64 * 64)
+             generation_kwargs.update({"height": height, "width": width})
+
+             generator = torch.Generator().manual_seed(seed + batch_idx)
+
+             # Drop size kwargs that this pipeline's __call__ does not accept
+             pipeline_signature = inspect.signature(pipeline)
+             pipeline_params = pipeline_signature.parameters.keys()
+
+             if 'height' not in pipeline_params:
+                 generation_kwargs.pop('height', None)
+                 print("Warning: Pipeline does not support 'height' parameter, removing from kwargs")
+             if 'width' not in pipeline_params:
+                 generation_kwargs.pop('width', None)
+                 print("Warning: Pipeline does not support 'width' parameter, removing from kwargs")
+
+             try:
+                 outputs = pipeline(
+                     prompt=batch_prompts, generator=generator, **generation_kwargs
+                 )
+             except Exception as e:
+                 print(e)
+                 continue
+             if self.pipe_type == "t2i":
+                 images = outputs.images
+             elif self.pipe_type == "t2v":
+                 images = outputs.frames[0]
+
+             for img, prompt, image_file, info in zip(
+                 images, batch_prompts, batch_image_file, batch_info_dict
+             ):
+                 if image_file is None:
+                     img_path = os.path.join(
+                         generation_path, info["save_name"] + ".png"
+                     )
+                 else:
+                     img_path = os.path.join(generation_path, image_file[:-4] + ".png")
+
+                 if not os.path.exists(os.path.dirname(img_path)):
+                     os.makedirs(os.path.dirname(img_path), exist_ok=True)
+                 img.save(img_path)
+                 if image_file is None:
+                     text_path = os.path.join(
+                         generation_path, info["save_name"] + ".txt"
+                     )
+                 else:
+                     text_path = os.path.join(generation_path, image_file[:-4] + ".txt")
+                 try:
+                     # Record the prompt and source identifier next to the image
+                     with open(text_path, "w") as f:
+                         f.write(prompt)
+                         f.write("\n")
+                         f.write(
+                             image_file
+                             if image_file is not None
+                             else info["save_name"]
+                         )
+                 except Exception:
+                     pass
+         return True
+
+     def generate(
+         self,
+         info_dict,
+         generation_path,
+         num_processes,
+         batch_size,
+         weight_dtype,
+         seed,
+         generation_kwargs,
+         base_resolution,
+         force_aspect_ratio,
+     ):
+         # One thread per device; each thread drives its own pipeline copy
+         with ThreadPoolExecutor(max_workers=num_processes) as executor:
+             futures = [
+                 executor.submit(
+                     self.generate_imgs,
+                     num_processes,
+                     batch_size,
+                     generation_path,
+                     info_dict,
+                     self.pipelines[device_id],
+                     device_id,
+                     weight_dtype,
+                     seed,
+                     base_resolution,
+                     force_aspect_ratio,
+                     generation_kwargs,
+                 )
+                 for device_id in range(num_processes)
+             ]
+
+             for future in as_completed(futures):
+                 print(f"Task completed: {future.result()}")
generate/utils/__init__.py ADDED
File without changes
generate/utils/pipelines.py ADDED
@@ -0,0 +1,282 @@
+ import torch
+ from pydantic import BaseModel, Field
+ from typing import Optional, Dict, Any
+
+ class PipelineParam(BaseModel):
+     pipeline_name: str
+     generation_path: str
+     pipeline_type: str = 't2i'
+     pipe_init_kwargs: Dict[str, Any] = Field(default_factory=dict)
+     generation_kwargs: Dict[str, Any] = Field(default_factory=dict)
+     base_resolution: int = 1024
+     force_aspect_ratio: Optional[int] = None
+
+ flux_dev_pipe = PipelineParam(
+     pipeline_name='pretrained_models/FLUX.1-dev',
+     generation_path='generation/flux_dev',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "guidance_scale": 3.5,
+         "num_inference_steps": 28,
+         "max_sequence_length": 512,
+     }
+ )
+
+ flux_schnell_pipe = PipelineParam(
+     pipeline_name='pretrained_models/FLUX.1-schnell',
+     generation_path='generation/flux_schnell',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "guidance_scale": 3.5,
+         "num_inference_steps": 4,
+     }
+ )
+
+
+ sd3_medium_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-3-medium-diffusers',
+     generation_path='generation/sd3_medium',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "guidance_scale": 7.0,
+         "num_inference_steps": 28,
+     }
+ )
+
+ sd_xl_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-xl-base-1.0',
+     generation_path='generation/sd_xl',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "guidance_scale": 5,
+         "num_inference_steps": 50,
+     }
+ )
+
+ sd_1_5_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-v1-5',
+     generation_path='generation/sd_1_5',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     generation_kwargs={
+     }
+ )
+
+ vq_diffusion_pipe = PipelineParam(
+     pipeline_name='pretrained_models/vq-diffusion-ithq',
+     generation_path='generation/vq_diffusion',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=256,
+     generation_kwargs={}
+ )
+
+ sd_2_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-2',
+     generation_path='generation/sd_2',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ sd_1_1_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-v1-1',
+     generation_path='generation/sd_1_1',
+     pipe_init_kwargs={"torch_dtype": torch.float16},
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ sd_1_4_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-v1-4',
+     generation_path='generation/sd_1_4',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ sd_2_1_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-2-1-base',
+     generation_path='generation/sd_2_1',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ openjourney_pipe = PipelineParam(
+     pipeline_name='pretrained_models/openjourney',
+     generation_path='generation/openjourney',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ playground_v2_5_pipe = PipelineParam(
+     pipeline_name='pretrained_models/playground-v2.5-1024px-aesthetic',
+     generation_path='generation/playground_v_2_5',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=1024,
+ )
+
+ versatile_pipe = PipelineParam(
+     pipeline_name='pretrained_models/versatile-diffusion',
+     generation_path='generation/versatile',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ glide_pipe = PipelineParam(
+     pipeline_name='pretrained_models/glide-base',
+     generation_path='generation/glide',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ sd_3_5_medium_pipe = PipelineParam(
+     pipeline_name='stabilityai/stable-diffusion-3.5-medium',
+     generation_path='generation/sd_3_5_medium',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "num_inference_steps": 40,
+         "guidance_scale": 4.5,
+     }
+ )
+
+ sd_3_5_large_pipe = PipelineParam(
+     pipeline_name='stabilityai/stable-diffusion-3.5-large',
+     generation_path='generation/sd_3_5_large',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "num_inference_steps": 28,
+         "guidance_scale": 3.5,
+     }
+ )
+
+ kolors_pipe = PipelineParam(
+     pipeline_name='pretrained_models/Kolors-diffusers',
+     generation_path='generation/kolors',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
197
+ 'variant': 'fp16',
198
+ },
199
+ base_resolution=1024,
200
+ generation_kwargs={
201
+ "num_inference_steps": 50,
202
+ "guidance_scale": 5.0,
203
+ }
204
+ )
205
+
206
+ cogview4_pipe = PipelineParam(
207
+ pipeline_name='pretrained_models/CogView4-6B',
208
+ generation_path=f'generation/cogview4',
209
+ pipe_init_kwargs={
210
+ "torch_dtype": torch.bfloat16,
211
+ },
212
+ base_resolution=1024,
213
+ generation_kwargs={
214
+ "num_inference_steps": 50,
215
+ "guidance_scale": 3.5,
216
+ }
217
+ )
218
+
219
+ pixart_sigma_pipe = PipelineParam(
220
+ pipeline_name='pretrained_models/PixArt-Sigma-XL-2-1024-MS',
221
+ generation_path=f'generation/pixart_sigma',
222
+ pipeline_type='t2i',
223
+ pipe_init_kwargs={
224
+ "torch_dtype": torch.bfloat16,
225
+ },
226
+ base_resolution=1024,
227
+ )
228
+
229
+ hunyuanvideo_pipe = PipelineParam(
230
+ pipeline_name='pretrained_models/hunyuanvideo_diffusers',
231
+ generation_path=f'generation/hunyuanvideo',
232
+ pipe_init_kwargs={
233
+ "torch_dtype": torch.bfloat16,
234
+ },
235
+ base_resolution=1024,
236
+ pipeline_type='t2v',
237
+ generation_kwargs={
238
+ "num_inference_steps": 30,
239
+ "num_frames": 1,
240
+ }
241
+ )
242
+
243
+ hunyuandit_pipe = PipelineParam(
244
+ pipeline_name='pretrained_models/HunyuanDiT-v1.2-Diffusers',
245
+ generation_path=f'generation/hunyuandit',
246
+ pipe_init_kwargs={
247
+ "torch_dtype": torch.float16,
248
+ },
249
+ base_resolution=1024,
250
+ pipeline_type='t2i',
251
+ generation_kwargs={
252
+ }
253
+ )
254
+
255
+ # API models
256
+ # Fal.ai
257
+ flux_pro_v1_1_ultr_pipe = PipelineParam(
258
+ pipeline_name='fal-ai/flux-pro/v1.1-ultra',
259
+ generation_path=f'generation/flux_pro_v1_1_ultra',
260
+ base_resolution=1024,
261
+ generation_kwargs={
262
+ "enable_safety_checker": False,
263
+ "num_images": 1,
264
+ # "aspect_ratio": "1:1",
265
+ "output_format": "jpeg",
266
+ "safety_tolerance": 5,
267
+ }
268
+ )
269
+
270
+ recraftv3_pipe = PipelineParam(
271
+ pipeline_name='fal-ai/recraft-v3',
272
+ generation_path=f'generation/recraftv3',
273
+ base_resolution=1024,
274
+ generation_kwargs={
275
+ "enable_safety_checker": False,
276
+ "num_images": 1,
277
+ # "aspect_ratio": "1:1",
278
+ "output_format": "jpeg",
279
+ "safety_tolerance": 5,
280
+ }
281
+ )
282
+
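Note: downstream code consumes these entries field by field. A minimal usage sketch, assuming the repo root is on PYTHONPATH (init_multiple_pipelines is defined in generate/utils/utils.py just below):

    # Hypothetical standalone use of a PipelineParam entry.
    from generate.utils.pipelines import sd_xl_pipe
    from generate.utils.utils import init_multiple_pipelines

    # One pipeline instance per CUDA device.
    pipes = init_multiple_pipelines(sd_xl_pipe.pipeline_name, sd_xl_pipe.pipe_init_kwargs, num_devices=2)
    images = pipes[0](
        prompt="a photo of an astronaut riding a horse",
        **sd_xl_pipe.generation_kwargs,  # guidance_scale=5, num_inference_steps=50
    ).images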
generate/utils/utils.py ADDED
@@ -0,0 +1,58 @@
+ import torch
+ try:
+     import fal_client
+ except ImportError:
+     fal_client = None
+
+ from diffusers import AutoPipelineForText2Image, HunyuanVideoPipeline, DiffusionPipeline
+ import json
+ import diffusers
+ import os
+ # export FAL_KEY="YOUR_API_KEY"
+ os.environ['FAL_KEY'] = 'YOUR_API_KEY'
+
+ def init_multiple_pipelines(pipe_name, pipe_init_kwargs, num_devices, device_id=None):
+     pipelines = []
+
+     if device_id is not None:
+         assert num_devices == 1
+
+     for i in range(num_devices):
+         actual_device_id = device_id if device_id is not None else i
+         try:
+             pipeline = AutoPipelineForText2Image.from_pretrained(pipe_name, **pipe_init_kwargs).to(f'cuda:{actual_device_id}')
+         except Exception:
+             # Fall back to the concrete pipeline class named in model_index.json
+             config = json.load(open(os.path.join(pipe_name, 'model_index.json')))
+             class_name_str = config['_class_name']
+             pipeline_class = getattr(diffusers, class_name_str)
+             pipeline = pipeline_class.from_pretrained(pipe_name, **pipe_init_kwargs).to(f'cuda:{actual_device_id}')
+         pipelines.append(pipeline)
+     return pipelines
+
+
+ def init_pipeline_from_names(pipe_names, weight_dtype):
+     pipelines_dict = {}
+     for name in pipe_names:
+         pipeline = AutoPipelineForText2Image.from_pretrained(name, torch_dtype=weight_dtype)
+         pipelines_dict[name] = pipeline
+     return pipelines_dict
+
+
+ def on_queue_update(update):
+     if isinstance(update, fal_client.InProgress):
+         for log in update.logs:
+             print(log["message"])
+
+ def gen_with_api(pipe_names, generation_kwargs):
+     result = fal_client.subscribe(
+         pipe_names,
+         arguments=generation_kwargs,
+         with_logs=True,
+         on_queue_update=on_queue_update,
+     )
+     return result
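For the Fal.ai entries at the bottom of pipelines.py, generation goes through gen_with_api instead of a local pipeline. A sketch, assuming a real key is exported as FAL_KEY rather than the placeholder above:

    # Hypothetical API-backed generation via fal_client.
    from generate.utils.utils import gen_with_api

    result = gen_with_api(
        'fal-ai/recraft-v3',
        {"prompt": "a watercolor fox in a snowy forest", "num_images": 1, "output_format": "jpeg"},
    )
    print(result)  # a dict describing the generated image(s); the exact shape depends on the endpoint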
hpsv3/__init__.py ADDED
@@ -0,0 +1 @@
+ from .inference import HPSv3RewardInferencer
hpsv3/cohp/__init__.py ADDED
File without changes
hpsv3/cohp/cohp_all.py ADDED
@@ -0,0 +1,290 @@
+ import os
+ import json
+ import gc
+ import random
+ import argparse
+ import torch
+ from generator import Generator
+ from utils_cohp.pipelines import *
+ from utils_cohp.image2image_pipeline import Image2ImagePipeline
+ from hpsv3.inference import HPSv3RewardInferencer
+ from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
+ import ImageReward as RM
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModel
+
+ def initialize_model(device, cp):
+     model_dict = {}
+     model, preprocess_train, preprocess_val = create_model_and_transforms(
+         'ViT-H-14',
+         'laion2B-s32B-b79K',
+         precision='amp',
+         device=device,
+         jit=False,
+         force_quick_gelu=False,
+         force_custom_text=False,
+         force_patch_dropout=False,
+         force_image_size=None,
+         pretrained_image=False,
+         image_mean=None,
+         image_std=None,
+         light_augmentation=True,
+         aug_cfg={},
+         output_dict=True,
+         with_score_predictor=False,
+         with_region_predictor=False
+     )
+
+     checkpoint = torch.load(cp, map_location=device, weights_only=False)
+     model.load_state_dict(checkpoint['state_dict'])
+     model = model.to(device)
+     model.eval()
+     tokenizer = get_tokenizer('ViT-H-14')
+
+     model_dict['model'] = model
+     model_dict['preprocess_val'] = preprocess_val
+     return model_dict, tokenizer
+
+ def score_hpsv2_batch(model_dict, tokenizer, device, img_paths: list, prompts: list) -> list:
+     model = model_dict['model']
+     preprocess_val = model_dict['preprocess_val']
+     # Preprocess the images as a batch
+     images = [preprocess_val(Image.open(p)).unsqueeze(0) for p in img_paths]
+     images = torch.cat(images, dim=0).to(device=device)
+     texts = tokenizer(prompts).to(device=device)
+     with torch.no_grad():
+         outputs = model(images, texts)
+         image_features, text_features = outputs["image_features"], outputs["text_features"]
+         logits_per_image = image_features @ text_features.T
+         hps_scores = torch.diagonal(logits_per_image).cpu()
+     return hps_scores
+
+ def pickscore_calc_probs(model, processor_pickscore, prompt, images, device):
+     # preprocess
+     image_inputs = processor_pickscore(
+         images=images,
+         padding=True,
+         truncation=True,
+         max_length=77,
+         return_tensors="pt",
+     ).to(device)
+
+     text_inputs = processor_pickscore(
+         text=prompt,
+         padding=True,
+         truncation=True,
+         max_length=77,
+         return_tensors="pt",
+     ).to(device)
+
+     with torch.no_grad():
+         # embed
+         image_embs = model.get_image_features(**image_inputs)
+         image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
+
+         text_embs = model.get_text_features(**text_inputs)
+         text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
+
+         # score
+         scores = text_embs @ image_embs.T
+
+     return scores
+
+ def generate_images(reward_type, prompt, index, pipeline_params, di_pipeline, inferencer, out_dir='cohp_output', num_rounds=5, strength=0.8, device='cuda:1'):
+     os.makedirs(out_dir, exist_ok=True)
+     os.makedirs(os.path.join(out_dir, 'result_json'), exist_ok=True)
+     batch_size = 2  # images generated per pipeline per round
+
+     results = []  # final result for each prompt
+
+     info_dict = {
+         'caption': prompt,
+         'width': 1024,
+         'height': 1024,
+         'aspect_ratio': 1,
+         'save_name': f"{index}_origin",
+     }
+     di_score_pipelines = {}  # average score per pipeline
+
+     # Intermediate records: image paths and scores for every round
+     intermediate_results_sample_preference = []
+     intermediate_results_model_preference = []
+
+     # Iterate over the candidate pipelines
+     for pipeline_param in pipeline_params:
+         name = di_pipeline[pipeline_param]
+         generator = Generator(
+             device=device,
+             pipe_name=pipeline_param.pipeline_name,
+             pipe_type=pipeline_param.pipeline_type,
+             pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
+         )
+         image_paths = generator.generate_imgs(
+             info_dict=info_dict,
+             generation_path=os.path.join(out_dir, pipeline_param.generation_path),
+             batch_size=batch_size,
+             device=device,
+             seed=random.randint(0, 75859066837),
+             weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
+             generation_kwargs=pipeline_param.generation_kwargs,
+         )
+
+         # Score the generated images
+         score_list = []
+         for image_path in image_paths:
+             if reward_type == 'hpsv2':
+                 score = score_hpsv2_batch(model_dict, tokenizer, device, [image_path], [prompt])
+                 score = score.item()
+             elif reward_type == 'hpsv3':
+                 score = inferencer.reward([image_path], [prompt]).cpu().detach()
+                 score = score[0][0].item()
+             elif reward_type == 'imagereward':
+                 score = inferencer.score(prompt, [image_path])
+             elif reward_type == 'pickscore':
+                 score = pickscore_calc_probs(inferencer, processor_pickscore, prompt, [Image.open(image_path)], device)[0][0].item()
+                 print(f"PickScore for {image_path}: {score}")
+             else:
+                 raise ValueError("Unsupported reward type. Choose 'hpsv2', 'hpsv3', 'pickscore', or 'imagereward'.")
+             score_list.append(score)
+
+         average = sum(score_list) / len(score_list)
+         di_score_pipelines[name] = average
+         # Record this pipeline's image paths and scores
+         intermediate_results_model_preference.append({
+             'pipeline': name,
+             'image_paths': image_paths,  # all generated image paths
+             'scores': score_list,  # scores for this round
+             'max_image_path': image_paths[score_list.index(max(score_list))],  # best image this round
+             'max_score': max(score_list)  # best score this round
+         })
+
+         # Release generator resources
+         generator.pipelines.to("cpu")
+         del generator
+         torch.cuda.empty_cache()
+         gc.collect()
+
+     # Pick the highest-scoring pipeline and its best image
+     max_key = max(di_score_pipelines, key=di_score_pipelines.get)
+     max_index = score_list.index(max(score_list))
+     image_path_chosen = image_paths[max_index]  # best image from the first stage
+
+     # Multi-round refinement loop
+     for round_num in range(num_rounds):
+         if round_num in (3, 4):
+             strength = 0.5
+         i2ipipeline = Image2ImagePipeline(max_key)
+         images = i2ipipeline.generate_image(
+             prompt=prompt,
+             image_path=image_path_chosen,
+             strength=strength,
+             batch_size=4,
+             save_prefix=f'{index}_{max_key}_image2image_round{round_num + 1}',
+             output_dir=out_dir
+         )
+
+         score_list = []
+         for image_path in images:
+             if reward_type == 'hpsv2':
+                 score = score_hpsv2_batch(model_dict, tokenizer, device, [image_path], [prompt])
+                 score = score.item()
+             elif reward_type == 'hpsv3':
+                 score = inferencer.reward([image_path], [prompt]).cpu().detach()
+                 score = score[0][0].item()
+             elif reward_type == 'imagereward':
+                 score = inferencer.score(prompt, [image_path])
+             elif reward_type == 'pickscore':
+                 score = pickscore_calc_probs(inferencer, processor_pickscore, prompt, [Image.open(image_path)], device)[0][0].item()
+                 print(f"PickScore for {image_path}: {score}")
+             else:
+                 raise ValueError("Unsupported reward type. Choose 'hpsv2', 'hpsv3', 'pickscore', or 'imagereward'.")
+             score_list.append(score)
+
+         intermediate_results_sample_preference.append({
+             'round': round_num + 1,
+             'image_paths': images,  # all generated image paths
+             'scores': score_list,  # scores for this round
+             'max_image_path': images[score_list.index(max(score_list))],  # best image this round
+             'max_score': max(score_list)  # best score this round
+         })
+
+         # Update the chosen image
+         max_index = score_list.index(max(score_list))
+         image_path_chosen = images[max_index]
+
+     # Save the final results
+     results.append({
+         'prompt': prompt,
+         'model_preference_image_chosen': image_path_chosen,
+         'model_preference_info': intermediate_results_model_preference,  # all intermediate results
+         'best_image_path': image_path_chosen,
+         'best_model': max_key,
+         'score': max(score_list),
+         'sample_preference_intermediate_results': intermediate_results_sample_preference,  # all intermediate results
+     })
+     with open(os.path.join(out_dir, 'result_json', f'{index}.json'), 'w', encoding='utf-8') as f:
+         json.dump(results, f, ensure_ascii=False, indent=4)
+
+     return results
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Image Generation Script")
+     parser.add_argument('--prompt', type=str, required=True, help='The prompt for image generation')
+     parser.add_argument('--index', type=str, required=True, help='Index for saving results')
+     parser.add_argument('--device', type=str, default='cuda:1', help='Device to run the model on')
+     parser.add_argument('--reward_model', type=str, default='hpsv3', help='Reward model to use (hpsv2, hpsv3, pickscore, or imagereward)')
+
+     args = parser.parse_args()
+     output_dir = f"cohp_output_{args.reward_model}"
+
+     os.makedirs(output_dir, exist_ok=True)
+     if args.reward_model == 'hpsv2':
+         model_dict, tokenizer = initialize_model(args.device, 'pretrained_models/HPS_v2.1_compressed.pt')
+         inferencer = model_dict  # the hpsv2 path reads model_dict/tokenizer directly
+     elif args.reward_model == 'hpsv3':
+         dtype = torch.bfloat16
+         inferencer = HPSv3RewardInferencer(device=args.device, dtype=dtype)
+     elif args.reward_model == 'imagereward':
+         inferencer = RM.load("ImageReward-v1.0").to(args.device)
+     elif args.reward_model == 'pickscore':
+         processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+         model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"
+         processor_pickscore = AutoProcessor.from_pretrained(processor_name_or_path)
+         inferencer = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(args.device)
+     else:
+         raise ValueError("Unsupported reward model. Choose 'hpsv2', 'hpsv3', 'pickscore', or 'imagereward'.")
+     pipeline_params = [
+         flux_dev_pipe,
+         kolors_pipe,
+         sd3_medium_pipe,
+         playground_v2_5_pipe
+     ]
+
+     di_pipeline = {
+         flux_dev_pipe: 'flux',
+         kolors_pipe: 'kolors',
+         sd3_medium_pipe: 'sd3',
+         playground_v2_5_pipe: 'playground_v2_5'
+     }
+
+     results = generate_images(
+         args.reward_model,
+         args.prompt,
+         args.index,
+         pipeline_params,
+         di_pipeline,
+         inferencer,
+         out_dir=output_dir,
+         num_rounds=4,
+         strength=0.8,
+         device=args.device)
hpsv3/cohp/generator.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ import os
+ import inspect
+ from PIL import Image
+ from utils_cohp.utils import init_pipelines
+
+ Image.MAX_IMAGE_PIXELS = None
+
+
+ class Generator:
+     def __init__(
+         self, pipe_name, pipe_type, pipe_init_kwargs, device=None
+     ):
+         self.pipe_names = pipe_name
+         self.pipe_type = pipe_type
+         self.pipe_init_kwargs = pipe_init_kwargs
+         self.pipelines = init_pipelines(
+             pipe_name, pipe_init_kwargs, device
+         )
+
+     def generate_imgs(
+         self,
+         batch_size,
+         generation_path,
+         info_dict,
+         device,
+         weight_dtype,
+         seed,
+         generation_kwargs,
+     ):
+         torch.cuda.set_device(device)
+         device = torch.device(device)
+         generator = torch.Generator().manual_seed(seed)
+
+         # Drop kwargs the pipeline's __call__ signature does not accept
+         pipeline_signature = inspect.signature(self.pipelines)
+         pipeline_params = pipeline_signature.parameters.keys()
+
+         if 'height' not in pipeline_params:
+             generation_kwargs.pop('height', None)
+             print("Warning: Pipeline does not support 'height' parameter, removing from kwargs")
+         if 'width' not in pipeline_params:
+             generation_kwargs.pop('width', None)
+             print("Warning: Pipeline does not support 'width' parameter, removing from kwargs")
+
+         outputs = self.pipelines(
+             prompt=info_dict['caption'], generator=generator, num_images_per_prompt=batch_size, **generation_kwargs
+         )
+         if self.pipe_type == "t2i":
+             images = outputs.images
+         elif self.pipe_type == "t2v":
+             images = outputs.frames[0]
+         image_paths = []
+         for idx, image in enumerate(images):
+             img_path = os.path.join(
+                 generation_path, info_dict["save_name"] + f"_{idx}.png"
+             )
+             os.makedirs(generation_path, exist_ok=True)
+             image.save(img_path)
+             image_paths.append(img_path)
+         return image_paths
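A sketch of driving this Generator directly, outside the CoHP loop (pipeline entries come from utils_cohp/pipelines.py; the paths are illustrative):

    import torch
    from utils_cohp.pipelines import sd3_medium_pipe

    gen = Generator(
        pipe_name=sd3_medium_pipe.pipeline_name,
        pipe_type=sd3_medium_pipe.pipeline_type,
        pipe_init_kwargs=sd3_medium_pipe.pipe_init_kwargs,
        device='cuda:0',
    )
    paths = gen.generate_imgs(
        batch_size=2,
        generation_path='cohp_output/generation/sd3_medium',
        info_dict={'caption': 'a red bicycle leaning on a brick wall', 'save_name': '0_origin'},
        device='cuda:0',
        weight_dtype=torch.float16,
        seed=42,
        generation_kwargs=sd3_medium_pipe.generation_kwargs,
    )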
hpsv3/cohp/run_cohp.py ADDED
@@ -0,0 +1,244 @@
+ import os
+ import json
+ import random
+ import gc
+ import argparse
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModel
+
+ from generator import Generator
+ from hpsv3.inference import HPSv3RewardInferencer
+ from hpsv3.cohp.utils_cohp.pipelines import *
+ from hpsv3.cohp.utils_cohp.image2image_pipeline import Image2ImagePipeline
+
+ try:
+     from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
+ except ImportError:
+     print("HPSv2 model not found, skipping HPSv2 related imports.")
+
+ try:
+     import ImageReward as RM
+ except ImportError:
+     print("ImageReward module not found, skipping ImageReward related imports.")
+
+
+ def initialize_hpsv2_model(device, checkpoint_path):
+     model_dict = {}
+     model, _, preprocess_val = create_model_and_transforms(
+         'ViT-H-14',
+         'laion2B-s32B-b79K',
+         device=device,
+         precision='amp',
+         pretrained_image=False,
+         output_dict=True,
+     )
+     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+     model.load_state_dict(checkpoint['state_dict'])
+     model = model.to(device).eval()
+     tokenizer = get_tokenizer('ViT-H-14')
+
+     model_dict['model'] = model
+     model_dict['preprocess_val'] = preprocess_val
+     return model_dict, tokenizer
+
+
+ def score_hpsv2(model_dict, tokenizer, device, img_paths, prompts):
+     model = model_dict['model']
+     preprocess_val = model_dict['preprocess_val']
+     images = [preprocess_val(Image.open(p)).unsqueeze(0) for p in img_paths]
+     images = torch.cat(images, dim=0).to(device)
+     texts = tokenizer(prompts).to(device)
+
+     with torch.no_grad():
+         outputs = model(images, texts)
+         image_features, text_features = outputs["image_features"], outputs["text_features"]
+         logits_per_image = image_features @ text_features.T
+         hps_scores = torch.diagonal(logits_per_image).cpu()
+     return hps_scores
+
+
+ def calculate_pickscore_probs(model, processor, prompt, images, device):
+     image_inputs = processor(images=images, padding=True, return_tensors="pt").to(device)
+     text_inputs = processor(text=prompt, padding=True, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         image_embs = model.get_image_features(**image_inputs)
+         image_embs /= torch.norm(image_embs, dim=-1, keepdim=True)
+
+         text_embs = model.get_text_features(**text_inputs)
+         text_embs /= torch.norm(text_embs, dim=-1, keepdim=True)
+
+         scores = text_embs @ image_embs.T
+     return scores
+
+
+ def generate_images(
+     reward_type, prompt, index, pipeline_params, pipelines_mapping, inferencer,
+     output_dir='cohp_output', num_rounds=5, strength=0.8, device='cuda:1'
+ ):
+     os.makedirs(output_dir, exist_ok=True)
+     result_json_dir = os.path.join(output_dir, 'result_json')
+     os.makedirs(result_json_dir, exist_ok=True)
+
+     info_dict = {
+         'caption': prompt,
+         'width': 1024,
+         'height': 1024,
+         'aspect_ratio': 1,
+         'save_name': f"{index}_origin",
+     }
+     di_score_pipelines = {}
+     intermediate_results_model_pref = {}
+     intermediate_results_sample_pref = {}
+     max_final_score = 0
+
+     for pipeline_param in pipeline_params:
+         generator = Generator(
+             device=device,
+             pipe_name=pipeline_param.pipeline_name,
+             pipe_type=pipeline_param.pipeline_type,
+             pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
+         )
+         image_paths = generator.generate_imgs(
+             info_dict=info_dict,
+             generation_path=os.path.join(output_dir, pipeline_param.generation_path),
+             batch_size=2,
+             device=device,
+             seed=random.randint(0, 75859066837),
+             weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
+             generation_kwargs=pipeline_param.generation_kwargs,
+         )
+
+         score_list = []
+         for image_path in image_paths:
+             if reward_type == 'hpsv2':
+                 score = score_hpsv2(model_dict, tokenizer, device, [image_path], [prompt]).item()
+             elif reward_type == 'hpsv3':
+                 score = inferencer.reward([image_path], [prompt]).cpu().detach()[0][0].item()
+             elif reward_type == 'imagereward':
+                 score = inferencer.score(prompt, [image_path])
+             elif reward_type == 'pickscore':
+                 score = calculate_pickscore_probs(inferencer, processor_pickscore, prompt, [Image.open(image_path)], device)[0][0].item()
+             else:
+                 raise ValueError(f"Unsupported reward type: {reward_type}")
+             score_list.append(score)
+
+         average_score = sum(score_list) / len(score_list)
+         pipeline_name = pipelines_mapping[pipeline_param]
+         di_score_pipelines[pipeline_name] = average_score
+
+         intermediate_results_model_pref[pipeline_name] = {
+             'image_paths': image_paths,
+             'scores': score_list,
+             'max_image_path': image_paths[score_list.index(max(score_list))],
+             'max_score': max(score_list),
+         }
+         generator.pipelines.to("cpu")
+         del generator
+         torch.cuda.empty_cache()
+         gc.collect()
+
+     # Select the best pipeline based on scores
+     best_pipeline = max(di_score_pipelines, key=di_score_pipelines.get)
+     best_pipeline_results = intermediate_results_model_pref[best_pipeline]
+     chosen_image_path = best_pipeline_results['max_image_path']
+
+     # Refinement with Image2ImagePipeline
+     i2ipipeline = Image2ImagePipeline(best_pipeline)
+     for round_num in range(num_rounds):
+         if round_num in [3, 4]:
+             strength = 0.5
+         images = i2ipipeline.generate_image(
+             prompt=prompt,
+             image_path=chosen_image_path,
+             strength=strength,
+             batch_size=4,
+             save_prefix=f'{index}_{best_pipeline}_image2image_round{round_num + 1}',
+             output_dir=output_dir,
+         )
+
+         score_list = []
+         for image_path in images:
+             if reward_type == 'hpsv2':
+                 score = score_hpsv2(model_dict, tokenizer, device, [image_path], [prompt]).item()
+             elif reward_type == 'hpsv3':
+                 score = inferencer.reward([image_path], [prompt]).cpu().detach()[0][0].item()
+             elif reward_type == 'imagereward':
+                 score = inferencer.score(prompt, [image_path])
+             elif reward_type == 'pickscore':
+                 score = calculate_pickscore_probs(inferencer, processor_pickscore, prompt, [Image.open(image_path)], device)[0][0].item()
+             else:
+                 raise ValueError(f"Unsupported reward type: {reward_type}")
+             score_list.append(score)
+
+         # Update intermediate results
+         intermediate_results_sample_pref[round_num + 1] = {
+             'image_paths': images,
+             'scores': score_list,
+             'max_image_path': images[score_list.index(max(score_list))],
+             'max_score': max(score_list),
+         }
+
+         # Determine best image during refinement
+         if max(score_list) > max_final_score:
+             max_final_score = max(score_list)
+             chosen_image_path = images[score_list.index(max(score_list))]
+
+     # Save final results
+     results = {
+         'prompt': prompt,
+         'best_model': best_pipeline,
+         'final_image_path': chosen_image_path,
+         'model_preference_info': intermediate_results_model_pref,
+         'sample_preference_intermediate_results': intermediate_results_sample_pref,
+     }
+     with open(os.path.join(result_json_dir, f'{index}.json'), 'w', encoding='utf-8') as file:
+         json.dump(results, file, ensure_ascii=False, indent=4)
+     return results
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Image Generation Script")
+     parser.add_argument('--prompt', type=str, required=True, help='The prompt for image generation')
+     parser.add_argument('--index', type=str, required=True, help='Index for saving results')
+     parser.add_argument('--device', type=str, default='cuda:1', help='Device to run the model on')
+     parser.add_argument('--reward_model', type=str, default='hpsv3', help='Reward model to use (hpsv2, hpsv3, pickscore, or imagereward)')
+     args = parser.parse_args()
+
+     # Initialize models and pipelines
+     output_dir = f"cohp_output_{args.reward_model}"
+     if args.reward_model == 'hpsv2':
+         model_dict, tokenizer = initialize_hpsv2_model(args.device, 'pretrained_models/HPS_v2.1_compressed.pt')
+         inferencer = model_dict
+     elif args.reward_model == 'hpsv3':
+         inferencer = HPSv3RewardInferencer(device=args.device)
+     elif args.reward_model == 'imagereward':
+         inferencer = RM.load("ImageReward-v1.0").to(args.device)
+     elif args.reward_model == 'pickscore':
+         processor_pickscore = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+         inferencer = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").eval().to(args.device)
+     else:
+         raise ValueError("Unsupported reward model.")
+
+     # Define pipelines
+     pipeline_params = [kolors_pipe, sd3_medium_pipe, playground_v2_5_pipe, flux_dev_pipe]
+     pipelines_mapping = {
+         flux_dev_pipe: 'flux',
+         kolors_pipe: 'kolors',
+         sd3_medium_pipe: 'sd3',
+         playground_v2_5_pipe: 'playground_v2_5',
+     }
+
+     # Generate images
+     results = generate_images(
+         reward_type=args.reward_model,
+         prompt=args.prompt,
+         index=args.index,
+         pipeline_params=pipeline_params,
+         pipelines_mapping=pipelines_mapping,
+         inferencer=inferencer,
+         output_dir=output_dir,
+         num_rounds=4,
+     )
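Typical invocation, assuming the checkpoints referenced in utils_cohp/pipelines.py have been downloaded under pretrained_models/:

    python hpsv3/cohp/run_cohp.py --prompt "a cozy cabin under the northern lights" --index 0 --device cuda:0 --reward_model hpsv3

The per-prompt summary is then written to cohp_output_hpsv3/result_json/0.json.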
hpsv3/cohp/utils_cohp/__init__.py ADDED
File without changes
hpsv3/cohp/utils_cohp/image2image_pipeline.py ADDED
@@ -0,0 +1,65 @@
+ from diffusers import FluxImg2ImgPipeline, KolorsImg2ImgPipeline, StableDiffusion3Img2ImgPipeline, StableDiffusionXLImg2ImgPipeline
+ from diffusers.utils import load_image
+ import torch
+ import os
+
+
+ class Image2ImagePipeline:
+     def __init__(
+         self, pipe_name, device='cuda'
+     ):
+         self.pipe_name = pipe_name
+         if self.pipe_name == 'flux':
+             self.pipeline = FluxImg2ImgPipeline.from_pretrained("pretrained_models/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
+             self.generation_path = 'generation/flux_dev'
+         elif self.pipe_name == 'kolors':
+             self.pipeline = KolorsImg2ImgPipeline.from_pretrained("pretrained_models/kolors", torch_dtype=torch.bfloat16).to(device)
+             self.generation_path = 'generation/kolors'
+         elif self.pipe_name == 'sd3':
+             self.pipeline = StableDiffusion3Img2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16).to(device)
+             self.generation_path = 'generation/sd3_medium'
+         elif self.pipe_name == 'playground_v2_5':
+             self.pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained("pretrained_models/playground-v2.5-1024px-aesthetic", torch_dtype=torch.bfloat16).to(device)
+             self.generation_path = 'generation/playground_v_2_5'
+         self.pipeline = self.pipeline.to(torch.bfloat16)
+
+     def generate_image(
+         self,
+         prompt,
+         image_path,
+         strength,
+         batch_size,
+         save_prefix,
+         output_dir
+     ):
+         image_load = load_image(image_path)
+         if self.pipe_name == 'flux':
+             # The Flux branch omits the negative prompt
+             images = self.pipeline(
+                 prompt=prompt,
+                 image=image_load,
+                 num_images_per_prompt=batch_size,
+                 strength=strength).images
+         else:
+             images = self.pipeline(
+                 prompt=prompt,
+                 negative_prompt='',
+                 image=image_load,
+                 num_images_per_prompt=batch_size,
+                 strength=strength).images
+         image_list = []
+         save_dir = os.path.join(output_dir, self.generation_path)
+         os.makedirs(save_dir, exist_ok=True)
+         for ind, img in enumerate(images):
+             save_path = os.path.join(save_dir, save_prefix + f'_{ind}.png')
+             image_list.append(save_path)
+             img.save(save_path)
+         return image_list
hpsv3/cohp/utils_cohp/pipelines.py ADDED
@@ -0,0 +1,290 @@
+ import torch
+
+
+ class PipelineParam:
+     pipeline_name: str
+     pipeline_type: str
+     generation_path: str
+     pipe_init_kwargs: dict
+     generation_kwargs: dict
+     base_resolution: int
+     force_aspect_ratio: int
+
+     def __init__(self, pipeline_name: str, generation_path: str, pipeline_type='t2i',
+                  pipe_init_kwargs: dict = None, generation_kwargs: dict = None,
+                  base_resolution: int = 1024, force_aspect_ratio: int = None):
+         self.pipeline_name = pipeline_name
+         self.pipeline_type = pipeline_type
+         self.generation_path = generation_path
+         self.pipe_init_kwargs = pipe_init_kwargs if pipe_init_kwargs is not None else {}
+         self.generation_kwargs = generation_kwargs if generation_kwargs is not None else {}
+         self.base_resolution = base_resolution
+         self.force_aspect_ratio = force_aspect_ratio
+
+ flux_dev_pipe = PipelineParam(
+     pipeline_name='pretrained_models/FLUX.1-dev',
+     generation_path='generation/flux_dev',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "guidance_scale": 3.5,
+         "num_inference_steps": 28,
+         "max_sequence_length": 512,
+     }
+ )
+
+ flux_schnell_pipe = PipelineParam(
+     pipeline_name='pretrained_models/FLUX.1-schnell',
+     generation_path='generation/flux_schnell',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "guidance_scale": 3.5,
+         "num_inference_steps": 4,
+     }
+ )
+
+ sd3_medium_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-3-medium-diffusers',
+     generation_path='generation/sd3_medium',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "guidance_scale": 7.0,
+         "num_inference_steps": 28,
+     }
+ )
+
+ sd_xl_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-xl-base-1.0',
+     generation_path='generation/sd_xl',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "guidance_scale": 5,
+         "num_inference_steps": 50,
+     }
+ )
+
+ sd_1_5_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-v1-5',
+     generation_path='generation/sd_1_5',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     generation_kwargs={}
+ )
+
+ vq_diffusion_pipe = PipelineParam(
+     pipeline_name='pretrained_models/vq-diffusion-ithq',
+     generation_path='generation/vq_diffusion',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=256,
+     generation_kwargs={}
+ )
+
+ sd_2_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-2',
+     generation_path='generation/sd_2',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ sd_1_1_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-v1-1',
+     generation_path='generation/sd_1_1',
+     pipe_init_kwargs={"torch_dtype": torch.float16},
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ sd_1_4_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-v1-4',
+     generation_path='generation/sd_1_4',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ sd_2_1_pipe = PipelineParam(
+     pipeline_name='pretrained_models/stable-diffusion-2-1-base',
+     generation_path='generation/sd_2_1',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ openjourney_pipe = PipelineParam(
+     pipeline_name='pretrained_models/openjourney',
+     generation_path='generation/openjourney',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ playground_v2_5_pipe = PipelineParam(
+     pipeline_name='pretrained_models/playground-v2.5-1024px-aesthetic',
+     generation_path='generation/playground_v_2_5',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=1024,
+ )
+
+ versatile_pipe = PipelineParam(
+     pipeline_name='pretrained_models/versatile-diffusion',
+     generation_path='generation/versatile',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ glide_pipe = PipelineParam(
+     pipeline_name='pretrained_models/glide-base',
+     generation_path='generation/glide',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=512,
+     force_aspect_ratio=1,
+ )
+
+ sd_3_5_medium_pipe = PipelineParam(
+     pipeline_name='stabilityai/stable-diffusion-3.5-medium',
+     generation_path='generation/sd_3_5_medium',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "num_inference_steps": 40,
+         "guidance_scale": 4.5,
+     }
+ )
+
+ sd_3_5_large_pipe = PipelineParam(
+     pipeline_name='stabilityai/stable-diffusion-3.5-large',
+     generation_path='generation/sd_3_5_large',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "num_inference_steps": 28,
+         "guidance_scale": 3.5,
+     }
+ )
+
+ kolors_pipe = PipelineParam(
+     pipeline_name='pretrained_models/Kolors-diffusers',
+     generation_path='generation/kolors',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+         "variant": "fp16",
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "num_inference_steps": 50,
+         "guidance_scale": 5.0,
+     }
+ )
+
+ cogview4_pipe = PipelineParam(
+     pipeline_name='pretrained_models/CogView4-6B',
+     generation_path='generation/cogview4',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     generation_kwargs={
+         "num_inference_steps": 50,
+         "guidance_scale": 3.5,
+     }
+ )
+
+ pixart_sigma_pipe = PipelineParam(
+     pipeline_name='pretrained_models/PixArt-Sigma-XL-2-1024-MS',
+     generation_path='generation/pixart_sigma',
+     pipeline_type='t2i',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+ )
+
+ hunyuanvideo_pipe = PipelineParam(
+     pipeline_name='pretrained_models/hunyuanvideo_diffusers',
+     generation_path='generation/hunyuanvideo',
+     pipe_init_kwargs={
+         "torch_dtype": torch.bfloat16,
+     },
+     base_resolution=1024,
+     pipeline_type='t2v',
+     generation_kwargs={
+         "num_inference_steps": 30,
+         "num_frames": 1,
+     }
+ )
+
+ hunyuandit_pipe = PipelineParam(
+     pipeline_name='pretrained_models/HunyuanDiT-v1.2-Diffusers',
+     generation_path='generation/hunyuandit',
+     pipe_init_kwargs={
+         "torch_dtype": torch.float16,
+     },
+     base_resolution=1024,
+     pipeline_type='t2i',
+     generation_kwargs={}
+ )
+
+ # API models
+ # Fal.ai
+ flux_pro_v1_1_ultr_pipe = PipelineParam(
+     pipeline_name='fal-ai/flux-pro/v1.1-ultra',
+     generation_path='generation/flux_pro_v1_1_ultra',
+     base_resolution=1024,
+     generation_kwargs={
+         "enable_safety_checker": False,
+         "num_images": 1,
+         # "aspect_ratio": "1:1",
+         "output_format": "jpeg",
+         "safety_tolerance": 5,
+     }
+ )
+
+ recraftv3_pipe = PipelineParam(
+     pipeline_name='fal-ai/recraft-v3',
+     generation_path='generation/recraftv3',
+     base_resolution=1024,
+     generation_kwargs={
+         "enable_safety_checker": False,
+         "num_images": 1,
+         # "aspect_ratio": "1:1",
+         "output_format": "jpeg",
+         "safety_tolerance": 5,
+     }
+ )
hpsv3/cohp/utils_cohp/utils.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ try:
+     import fal_client
+ except ImportError:
+     fal_client = None
+ try:
+     from diffusers import AutoPipelineForText2Image, DiffusionPipeline
+ except ImportError:
+     AutoPipelineForText2Image = None
+     DiffusionPipeline = None
+
+ import json
+ import diffusers
+ import os
+ # export FAL_KEY="YOUR_API_KEY"
+ os.environ['FAL_KEY'] = 'YOUR_API_KEY'
+
+ def init_pipelines(pipe_name, pipe_init_kwargs, device=None):
+     try:
+         pipeline = AutoPipelineForText2Image.from_pretrained(pipe_name, **pipe_init_kwargs).to(device)
+     except Exception:
+         # Fall back to the concrete pipeline class named in model_index.json
+         config = json.load(open(os.path.join(pipe_name, 'model_index.json')))
+         class_name_str = config['_class_name']
+         pipeline_class = getattr(diffusers, class_name_str)
+         pipeline = pipeline_class.from_pretrained(pipe_name, **pipe_init_kwargs).to(device)
+
+     return pipeline
+
+
+ def init_pipeline_from_names(pipe_names, weight_dtype):
+     pipelines_dict = {}
+     for name in pipe_names:
+         pipeline = AutoPipelineForText2Image.from_pretrained(name, torch_dtype=weight_dtype)
+         pipelines_dict[name] = pipeline
+     return pipelines_dict
+
+
+ def on_queue_update(update):
+     if isinstance(update, fal_client.InProgress):
+         for log in update.logs:
+             print(log["message"])
+
+ def gen_with_api(pipe_names, generation_kwargs):
+     result = fal_client.subscribe(
+         pipe_names,
+         arguments=generation_kwargs,
+         with_logs=True,
+         on_queue_update=on_queue_update,
+     )
+     return result
hpsv3/config/HPSv3_7B.yaml ADDED
@@ -0,0 +1,60 @@
+ # Model Configuration
+ rm_head_type: "ranknet"
+ lora_enable: False
+ vision_lora: False
+ freeze_vision_tower: False
+ freeze_llm: False
+ tune_merger: True
+ model_name_or_path: "Qwen/Qwen2-VL-7B-Instruct"
+ num_lora_modules: -1
+ lora_r: 512
+ lora_alpha: 1024
+ lora_namespan_exclude: ['lm_head', 'rm_head', 'embed_tokens']
+
+ # Data Configuration
+ confidence_threshold: 0.95
+ tied_threshold: null
+ max_pixels: 200704  # 256 * 28 * 28
+ min_pixels: 200704
+ with_instruction: true
+
+ train_json_list:
+   - example_train.json
+ test_json_list:
+   - ["Valid Set 1", ["example_set_1_part1.json", "example_set_1_part2.json"]]
+   - ["Valid Set 2", ["example_set_2_part1.json"]]
+
+ soft_label: False
+ output_dir: output_models
+ use_special_tokens: true
+ reward_token: "special"
+ output_dim: 2
+ loss_type: "uncertainty"
+
+ # Training Configuration
+ disable_flash_attn2: False
+ per_device_train_batch_size: 2
+ per_device_eval_batch_size: 8
+ gradient_accumulation_steps: 4
+ num_train_epochs: 10
+ learning_rate: 2.0e-6
+ special_token_lr: 2.0e-6
+ warmup_ratio: 0.05
+ lr_scheduler_type: "constant_with_warmup"
+ gradient_checkpointing: True
+ gradient_checkpointing_kwargs: {"use_reentrant": False}
+
+ # Evaluation and Logging
+ eval_strategy: "steps"
+ logging_epochs: 0.01
+ eval_epochs: 0.1
+ save_epochs: 0.1
+ report_to: tensorboard
+
+ # System Configuration
+ bf16: True
+ torch_dtype: "bfloat16"
+ deepspeed: hpsv3/config/ds_config/zero2.json
+ save_only_model: True
+ save_full_model: True
+ dataloader_num_workers: 8
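The training entry point itself is outside this diff, so treat the exact consumer as an assumption; the YAML maps directly onto keyword arguments and can be sanity-checked with a minimal loader:

    import yaml

    # Keys mirror HuggingFace TrainingArguments plus HPSv3-specific fields
    # such as rm_head_type, confidence_threshold, and reward_token.
    with open("hpsv3/config/HPSv3_7B.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["model_name_or_path"])  # Qwen/Qwen2-VL-7B-Instruct
    print(cfg["max_pixels"])          # 200704 == 256 * 28 * 28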
hpsv3/config/ds_config/zero0.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "train_micro_batch_size_per_gpu": "auto",
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": "auto",
+     "zero_optimization": {
+         "stage": 0
+     }
+ }
hpsv3/config/ds_config/zero2.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "train_micro_batch_size_per_gpu": "auto",
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": "auto",
+     "zero_optimization": {
+         "stage": 2,
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "sub_group_size": 1e9,
+         "reduce_bucket_size": "auto"
+     }
+ }
hpsv3/config/ds_config/zero3.json ADDED
@@ -0,0 +1,28 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "train_micro_batch_size_per_gpu": "auto",
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": "auto",
+     "zero_optimization": {
+         "stage": 3,
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "sub_group_size": 1e9,
+         "reduce_bucket_size": "auto",
+         "stage3_prefetch_bucket_size": "auto",
+         "stage3_param_persistence_threshold": "auto",
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9,
+         "stage3_gather_16bit_weights_on_model_save": true
+     }
+ }
hpsv3/dataset/data_collator_qwen.py ADDED
@@ -0,0 +1,205 @@
+ import numpy as np
+ import torch
+ from hpsv3.dataset.utils import process_vision_info
+
+ INSTRUCTION = """
+ You are tasked with evaluating a generated image based on Visual Quality and Text Alignment and giving an overall score to estimate the human preference. Please provide a rating from 0 to 10, with 0 being the worst and 10 being the best.
+
+ **Visual Quality:**
+ Evaluate the overall visual quality of the image. The following sub-dimensions should be considered:
+ - **Reasonableness:** The image should not contain any significant biological or logical errors, such as abnormal body structures or nonsensical environmental setups.
+ - **Clarity:** Evaluate the sharpness and visibility of the image. The image should be clear and easy to interpret, with no blurring or indistinct areas.
+ - **Detail Richness:** Consider the level of detail in textures, materials, lighting, and other visual elements (e.g., hair, clothing, shadows).
+ - **Aesthetic and Creativity:** Assess the artistic aspects of the image, including the color scheme, composition, atmosphere, depth of field, and the overall creative appeal. The scene should convey a sense of harmony and balance.
+ - **Safety:** The image should not contain harmful or inappropriate content, such as political, violent, or adult material. If such content is present, the image quality and satisfaction score should be the lowest possible.
+
+ **Text Alignment:**
+ Assess how well the image matches the textual prompt across the following sub-dimensions:
+ - **Subject Relevance:** Evaluate how accurately the subject(s) in the image (e.g., person, animal, object) align with the textual description. The subject should match the description in terms of number, appearance, and behavior.
+ - **Style Relevance:** If the prompt specifies a particular artistic or stylistic style, evaluate how well the image adheres to this style.
+ - **Contextual Consistency:** Assess whether the background, setting, and surrounding elements in the image logically fit the scenario described in the prompt. The environment should support and enhance the subject without contradictions.
+ - **Attribute Fidelity:** Check if specific attributes mentioned in the prompt (e.g., colors, clothing, accessories, expressions, actions) are faithfully represented in the image. Minor deviations may be acceptable, but critical attributes should be preserved.
+ - **Semantic Coherence:** Evaluate whether the overall meaning and intent of the prompt are captured in the image. The generated content should not introduce elements that conflict with or distort the original description.
+ Textual prompt - {text_prompt}
+ """
+
+ INSTRUCTION_debug = """
+ {text_prompt}
+ """
+
+ prompt_with_special_token = """
+ Please provide the overall ratings of this image: <|Reward|>
+
+ END
+ """
+
+ prompt_without_special_token = """
+ Please provide the overall ratings of this image:
+ """
+
+
+ class QWen2VLDataCollator:
+     def __init__(
+         self,
+         processor,
+         with_instruction=True,
+         max_pixels=256 * 28 * 28,  # Default max pixels
+         min_pixels=256 * 28 * 28,  # Default min pixels
+         use_special_tokens=True,
+     ):
+         self.processor = processor
+         self.with_instruction = with_instruction
+         self.max_pixels = max_pixels
+         self.min_pixels = min_pixels
+         self.use_special_tokens = use_special_tokens
+
+     def _clean_message(
+         self,
+         texts,
+         images,
+         max_pixels=256 * 28 * 28,
+         min_pixels=256 * 28 * 28,
+         with_instruction=True,
+         use_special_tokens=True,
+     ):
+         """
+         Build one chat message per (text, image) pair, keeping only the keys
+         the processor expects.
+         """
+         message_list = []
+         for text, image in zip(texts, images):
+             out_message = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "image",
+                             "image": image,
+                             "min_pixels": min_pixels,
+                             "max_pixels": max_pixels,
+                         },
+                         {
+                             "type": "text",
+                             "text": (
+                                 INSTRUCTION.format(text_prompt=text)
+                                 + prompt_with_special_token
+                                 if use_special_tokens
+                                 else prompt_without_special_token
+                             ),
+                         },
+                     ],
+                 }
+             ]
+
+             message_list.append(out_message)
+
+         return message_list
+
+     def _pad_sequence(self, sequences, attention_mask, max_len, padding_side="right"):
+         """
+         Pad the sequences to the maximum length.
+         """
+         assert padding_side in ["right", "left"]
+         if sequences.shape[1] >= max_len:
+             return sequences, attention_mask
+
+         pad_len = max_len - sequences.shape[1]
+         padding = (0, pad_len) if padding_side == "right" else (pad_len, 0)
+
+         sequences_padded = torch.nn.functional.pad(
+             sequences, padding, "constant", self.processor.tokenizer.pad_token_id
+         )
+         attention_mask_padded = torch.nn.functional.pad(
+             attention_mask, padding, "constant", 0
+         )
+
+         return sequences_padded, attention_mask_padded
+
+     def __call__(self, inputs, with_instruction=True):
+         """
+         Preprocess inputs to token sequences and return a batch
+         """
+         images_1, images_2, texts_1, texts_2 = [], [], [], []
+
+         for idx, batch in enumerate(inputs):
+             texts_1.append(batch["text_1"])
+             texts_2.append(batch["text_2"])
+             images_1.append(batch["image_1"])
+             images_2.append(batch["image_2"])
+
+         messages_batch_1 = self._clean_message(
+             texts_1,
+             images_1,
+             max_pixels=self.max_pixels,
+             min_pixels=self.min_pixels,
+             with_instruction=self.with_instruction,
+             use_special_tokens=self.use_special_tokens,
+         )
+         messages_batch_2 = self._clean_message(
+             texts_2,
+             images_2,
+             max_pixels=self.max_pixels,
+             min_pixels=self.min_pixels,
+             with_instruction=self.with_instruction,
+             use_special_tokens=self.use_special_tokens,
+         )
+         image_inputs_1, _ = process_vision_info(messages_batch_1)
+         image_inputs_2, _ = process_vision_info(messages_batch_2)
+         image_inputs_1 = [
+             np.array(image_inputs_1[i]) / 255.0 for i in range(len(image_inputs_1))
+         ]
+         image_inputs_2 = [
+             np.array(image_inputs_2[i]) / 255.0 for i in range(len(image_inputs_2))
+         ]
+         do_rescale = False
+
+         batch_1 = self.processor(
+             text=self.processor.apply_chat_template(
+                 messages_batch_1, tokenize=False, add_generation_prompt=True
+             ),
+             images=image_inputs_1,
+             videos=None,
+             padding=True,
+             return_tensors="pt",
+             images_kwargs={"do_rescale": do_rescale},
+         )
+         batch_2 = self.processor(
+             text=self.processor.apply_chat_template(
+                 messages_batch_2, tokenize=False, add_generation_prompt=True
+             ),
+             images=image_inputs_2,
+             videos=None,
+             padding=True,
+             return_tensors="pt",
+             images_kwargs={"do_rescale": do_rescale},
+         )
+
+         # Pad both sides of each pair to a common length so they can be stacked
+         max_len = max(batch_1["input_ids"].shape[1], batch_2["input_ids"].shape[1])
+         batch_1["input_ids"], batch_1["attention_mask"] = self._pad_sequence(
+             batch_1["input_ids"], batch_1["attention_mask"], max_len, "right"
+         )
+         batch_2["input_ids"], batch_2["attention_mask"] = self._pad_sequence(
+             batch_2["input_ids"], batch_2["attention_mask"], max_len, "right"
+         )
+
+         batch = {
+             "batch_1": batch_1,
+             "batch_2": batch_2,
+             "choice_dist": torch.stack([batch["choice_dist"] for batch in inputs]),
+             # Store original text prompts for visualization
+             "text_1": texts_1,
+             "text_2": texts_2,
+             "image_1": image_inputs_1,
+             "image_2": image_inputs_2,
+         }
+
+         return batch
hpsv3/dataset/pairwise_dataset.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import random
4
+ import json
5
+ import os
6
+ from tqdm import tqdm
7
+
8
+ class PairwiseOriginalDataset(Dataset):
9
+ def __init__(
10
+ self,
11
+ json_list,
12
+ soft_label=False,
13
+ confidence_threshold=None,
14
+ ):
15
+ self.samples = []
16
+ for json_file in json_list:
17
+ with open(json_file, "r") as f:
18
+ data = json.load(f)
19
+ self.samples.extend(data)
20
+
21
+ self.soft_label = soft_label
22
+ self.confidence_threshold = confidence_threshold
23
+
24
+ if confidence_threshold is not None:
25
+ new_samples = []
26
+ for sample in tqdm(
27
+ self.samples, desc="Filtering samples according to confidence threshold"
28
+ ):
29
+ if sample.get("confidence", float("inf")) >= confidence_threshold:
30
+ new_samples.append(sample)
31
+ self.samples = new_samples
32
+
33
+ def __len__(self):
34
+ return len(self.samples)
35
+
36
+ def __getitem__(self, idx):
37
+         while True:
+             try:
+                 return self.get_single_item(idx)
+             except Exception as e:
+                 print(f"Error processing sample at index {idx}: {e}")
+                 import traceback
+                 traceback.print_exc()
+                 # Retry with a random replacement sample instead of failing
+                 # the whole epoch on one bad entry.
+                 idx = random.randint(0, len(self.samples) - 1)
49
+
50
+ def get_single_item(self, idx):
51
+ sample = self.samples[idx]
52
+ # Load image paths
53
+ image_1 = sample["path1"]
54
+ image_2 = sample["path2"]
55
+         assert os.path.exists(image_1) and os.path.exists(image_2), f'Missing image file: {image_1} or {image_2}'
56
+ text_1 = sample["prompt"]
57
+ text_2 = sample["prompt"]
58
+
59
+ # Process Label
60
+ if self.soft_label:
61
+ choice_dist = sorted(sample["choice_dist"], reverse=True)
62
+ assert (
63
+ torch.sum(torch.tensor(choice_dist)) > 0
64
+ ), "Choice distribution cannot be zero."
65
+ label = torch.tensor(choice_dist[0]) / torch.sum(torch.tensor(choice_dist))
66
+ else:
67
+ label = torch.tensor(1).float()
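+         # Worked example: with soft_label=True and choice_dist == [7, 3],
+         # label = 7 / (7 + 3) = 0.7, i.e. the winner's vote share.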
69
+ return {
70
+ "image_1": image_1,
71
+ "image_2": image_2,
72
+ "text_1": text_1,
73
+ "text_2": text_2,
74
+ "label": label,
75
+ "confidence": sample.get("confidence", 1.0),
76
+ "choice_dist": torch.tensor(sample.get("choice_dist", [1.0, 0.0])),
77
+ }
hpsv3/dataset/utils.py ADDED
@@ -0,0 +1,426 @@
1
+ from __future__ import annotations
2
+
3
+ ## This file is modified from https://github.com/kq-chen/qwen-vl-utils/blob/main/src/qwen_vl_utils/vision_process.py
+ import base64
6
+ import logging
7
+ import math
8
+ import os
9
+ import sys
10
+ import time
11
+ import warnings
12
+ from functools import lru_cache
13
+ from io import BytesIO
14
+
15
+ import requests
16
+ import torch
17
+ import torchvision
18
+ from packaging import version
19
+ from PIL import Image
20
+ from torchvision import io, transforms
21
+ from torchvision.transforms import InterpolationMode
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ IMAGE_FACTOR = 28
27
+ MIN_PIXELS = 4 * 28 * 28
28
+ MAX_PIXELS = 16384 * 28 * 28
29
+ MAX_RATIO = 200
30
+
31
+ VIDEO_MIN_PIXELS = 128 * 28 * 28
32
+ VIDEO_MAX_PIXELS = 768 * 28 * 28
33
+ VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
34
+ FRAME_FACTOR = 2
35
+ FPS = 2.0
36
+ FPS_MIN_FRAMES = 4
37
+ FPS_MAX_FRAMES = 768
38
+
39
+
40
+ def round_by_factor(number: int, factor: int) -> int:
41
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
42
+ return round(number / factor) * factor
43
+
44
+
45
+ def ceil_by_factor(number: int, factor: int) -> int:
46
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
47
+ return math.ceil(number / factor) * factor
48
+
49
+
50
+ def floor_by_factor(number: int, factor: int) -> int:
51
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
52
+ return math.floor(number / factor) * factor
53
+
54
+
55
+ def smart_resize(
56
+ height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
57
+ ) -> tuple[int, int]:
58
+ """
59
+ Rescales the image so that the following conditions are met:
60
+
61
+ 1. Both dimensions (height and width) are divisible by 'factor'.
62
+
63
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
64
+
65
+ 3. The aspect ratio of the image is maintained as closely as possible.
66
+ """
67
+ if max(height, width) / min(height, width) > MAX_RATIO:
68
+ raise ValueError(
69
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
70
+ )
71
+ h_bar = max(factor, round_by_factor(height, factor))
72
+ w_bar = max(factor, round_by_factor(width, factor))
73
+ if h_bar * w_bar > max_pixels:
74
+ beta = math.sqrt((height * width) / max_pixels)
75
+ h_bar = floor_by_factor(height / beta, factor)
76
+ w_bar = floor_by_factor(width / beta, factor)
77
+ elif h_bar * w_bar < min_pixels:
78
+ beta = math.sqrt(min_pixels / (height * width))
79
+ h_bar = ceil_by_factor(height * beta, factor)
80
+ w_bar = ceil_by_factor(width * beta, factor)
81
+ return h_bar, w_bar
82
+
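+ # Worked example: smart_resize(1000, 2000) rounds each side to the nearest
+ # multiple of IMAGE_FACTOR (28), giving (1008, 1988); that pixel count already
+ # lies inside [MIN_PIXELS, MAX_PIXELS], so no further scaling is applied.
+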
83
+
84
+ def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
85
+ if "image" in ele:
86
+ image = ele["image"]
87
+ else:
88
+ image = ele["image_url"]
89
+ image_obj = None
90
+ if isinstance(image, Image.Image):
91
+ image_obj = image
92
+ elif isinstance(image, torch.Tensor):
93
+ image_obj = image
94
+ elif image.startswith("http://") or image.startswith("https://"):
95
+ image_obj = Image.open(requests.get(image, stream=True).raw)
96
+ elif image.startswith("file://"):
97
+ image_obj = Image.open(image[7:])
98
+ elif image.startswith("data:image"):
99
+ if "base64," in image:
100
+ _, base64_data = image.split("base64,", 1)
101
+ data = base64.b64decode(base64_data)
102
+ image_obj = Image.open(BytesIO(data))
103
+ else:
104
+ image_obj = Image.open(image)
105
+ if image_obj is None:
106
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
107
+ if isinstance(image_obj, Image.Image):
108
+ image = image_obj.convert("RGB")
109
+ ## resize
110
+ if "resized_height" in ele and "resized_width" in ele:
111
+ resized_height, resized_width = smart_resize(
112
+ ele["resized_height"],
113
+ ele["resized_width"],
114
+ factor=size_factor,
115
+ )
116
+ else:
117
+ if isinstance(image, torch.Tensor):
118
+ shape = image.shape
119
+ if len(shape) == 4:
120
+ if shape[1] in [1, 3]: # Likely [B, C, H, W]
121
+ height, width = shape[2], shape[3]
122
+ image_mode = 'NCHW'
123
+ elif shape[3] in [1, 3]: # Likely [B, H, W, C]
124
+ height, width = shape[1], shape[2]
125
+                 image_mode = 'NHWC'
+             else:
+                 raise ValueError(f"Cannot determine tensor image format from shape {shape}")
+
127
+ elif len(shape) == 3:
128
+ if shape[0] in [1, 3]: # Likely [C, H, W]
129
+ height, width = shape[1], shape[2]
130
+ image_mode = 'CHW'
131
+ elif shape[2] in [1, 3]: # Likely [H, W, C]
132
+ height, width = shape[0], shape[1]
133
+ image_mode = 'HWC'
134
+ else:
135
+ raise ValueError(f"Cannot determine tensor image format from shape {shape}")
136
+ else:
137
+ raise ValueError(f"Unsupported tensor image shape: {shape}")
138
+ else:
139
+ width, height = image.size
140
+ min_pixels = ele.get("min_pixels", MIN_PIXELS)
141
+ max_pixels = ele.get("max_pixels", MAX_PIXELS)
142
+ resized_height, resized_width = smart_resize(
143
+ height,
144
+ width,
145
+ factor=size_factor,
146
+ min_pixels=min_pixels,
147
+ max_pixels=max_pixels,
148
+ )
149
+
150
+ if isinstance(image, torch.Tensor):
151
+ if image_mode == 'NCHW':
152
+ image = transforms.functional.resize(
153
+ image, [resized_height, resized_width], interpolation=InterpolationMode.BICUBIC, antialias=True
154
+ )
155
+ elif image_mode == 'NHWC':
156
+ image = transforms.functional.resize(
157
+ image.permute(0, 3, 1, 2), [resized_height, resized_width], interpolation=InterpolationMode.BICUBIC, antialias=True
158
+ )
159
+ elif image_mode == 'CHW':
160
+ image = image.unsqueeze(0) # Add batch dimension
161
+ image = transforms.functional.resize(
162
+ image, [resized_height, resized_width], interpolation=InterpolationMode.BICUBIC, antialias=True
163
+ )
164
+ elif image_mode == 'HWC':
165
+ image = image.permute(2, 0, 1).unsqueeze(0) # Add batch dimension and change to CHW
166
+ image = transforms.functional.resize(
167
+ image, [resized_height, resized_width], interpolation=InterpolationMode.BICUBIC, antialias=True
168
+ )
169
+
170
+ else:
171
+ # If the image is a PIL Image, we resize it using PIL.
172
+ if image.mode != "RGB":
173
+ image = image.convert("RGB")
174
+ image = image.resize((resized_width, resized_height), Image.BICUBIC)
175
+
176
+ return image
177
+
178
+
179
+ def smart_nframes(
180
+ ele: dict,
181
+ total_frames: int,
182
+ video_fps: int | float,
183
+ ) -> int:
184
+ """calculate the number of frames for video used for model inputs.
185
+
186
+ Args:
187
+ ele (dict): a dict contains the configuration of video.
188
+ support either `fps` or `nframes`:
189
+ - nframes: the number of frames to extract for model inputs.
190
+ - fps: the fps to extract frames for model inputs.
191
+ - min_frames: the minimum number of frames of the video, only used when fps is provided.
192
+ - max_frames: the maximum number of frames of the video, only used when fps is provided.
193
+ total_frames (int): the original total number of frames of the video.
194
+ video_fps (int | float): the original fps of the video.
195
+
196
+ Raises:
197
+         ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
198
+
199
+ Returns:
200
+ int: the number of frames for video used for model inputs.
201
+ """
202
+ assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
203
+ if "nframes" in ele:
204
+ nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
205
+ else:
206
+ fps = ele.get("fps", FPS)
207
+ min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
208
+ max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
209
+ nframes = total_frames / video_fps * fps
210
+ nframes = min(max(nframes, min_frames), max_frames)
211
+ nframes = round_by_factor(nframes, FRAME_FACTOR)
212
+ if nframes > total_frames:
213
+ nframes = total_frames
214
+ if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
215
+ raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
216
+ return nframes
217
+
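+ # Worked example: a 10 s clip at 30 fps (total_frames=300) with the default
+ # fps=2.0 gives nframes = 300 / 30 * 2 = 20, which is inside
+ # [FPS_MIN_FRAMES, min(FPS_MAX_FRAMES, 300)] and divisible by FRAME_FACTOR,
+ # so 20 frames are sampled.
+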
218
+
219
+ def _read_video_torchvision(
220
+ ele: dict,
221
+ ) -> torch.Tensor:
222
+ """read video using torchvision.io.read_video
223
+
224
+ Args:
225
+ ele (dict): a dict contains the configuration of video.
226
+ support keys:
227
+ - video: the path of video. support "file://", "http://", "https://" and local path.
228
+ - video_start: the start time of video.
229
+ - video_end: the end time of video.
230
+ Returns:
231
+ torch.Tensor: the video tensor with shape (T, C, H, W).
232
+ """
233
+ video_path = ele["video"]
234
+ if version.parse(torchvision.__version__) < version.parse("0.19.0"):
235
+ if "http://" in video_path or "https://" in video_path:
236
+ warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
237
+ if "file://" in video_path:
238
+ video_path = video_path[7:]
239
+ st = time.time()
240
+ video, audio, info = io.read_video(
241
+ video_path,
242
+ start_pts=ele.get("video_start", 0.0),
243
+ end_pts=ele.get("video_end", None),
244
+ pts_unit="sec",
245
+ output_format="TCHW",
246
+ )
247
+
248
+ total_frames, video_fps = video.size(0), info["video_fps"]
249
+ # logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
250
+     if ele.get('sample_type', 'uniform') == 'uniform':
251
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
252
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
253
+ elif ele['sample_type'] == 'multi_pts':
254
+ frames_each_pts = 6
255
+ num_pts = 4
256
+ fps = 8
257
+ nframes = int(total_frames * fps // video_fps)
258
+ frames_idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
259
+
260
+ start_pt = int(frames_each_pts // 2)
261
+ end_pt = int(nframes - frames_each_pts // 2 - 1)
262
+ pts = torch.linspace(start_pt, end_pt, num_pts).round().long().tolist()
263
+ idx = []
264
+ for pt in pts:
265
+ idx.extend(frames_idx[pt - frames_each_pts // 2 : pt + frames_each_pts // 2])
266
+
267
+ video = video[idx]
268
+ return video
269
+
270
+
271
+ def is_decord_available() -> bool:
272
+ import importlib.util
273
+
274
+ return importlib.util.find_spec("decord") is not None
275
+
276
+
277
+ def _read_video_decord(
278
+ ele: dict,
279
+ ) -> torch.Tensor:
280
+ """read video using decord.VideoReader
281
+
282
+ Args:
283
+ ele (dict): a dict contains the configuration of video.
284
+ support keys:
285
+ - video: the path of video. support "file://", "http://", "https://" and local path.
286
+ - video_start: the start time of video.
287
+ - video_end: the end time of video.
288
+ Returns:
289
+ torch.Tensor: the video tensor with shape (T, C, H, W).
290
+ """
291
+ import decord
292
+ video_path = ele["video"]
293
+ st = time.time()
294
+ vr = decord.VideoReader(video_path)
295
+ # TODO: support start_pts and end_pts
296
+ if 'video_start' in ele or 'video_end' in ele:
297
+ raise NotImplementedError("not support start_pts and end_pts in decord for now.")
298
+ total_frames, video_fps = len(vr), vr.get_avg_fps()
299
+ # logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
300
+     if ele.get('sample_type', 'uniform') == 'uniform':
301
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
304
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
305
+ elif ele['sample_type'] == 'multi_pts':
306
+ frames_each_pts = 6
307
+ num_pts = 4
308
+ fps = 8
309
+ nframes = int(total_frames * fps // video_fps)
310
+ frames_idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
311
+
312
+ start_pt = int(frames_each_pts // 2)
313
+ end_pt = int(nframes - frames_each_pts // 2 - 1)
314
+ pts = torch.linspace(start_pt, end_pt, num_pts).round().long().tolist()
315
+ idx = []
316
+ for pt in pts:
317
+ idx.extend(frames_idx[pt - frames_each_pts // 2 : pt + frames_each_pts // 2])
318
+ video = vr.get_batch(idx).asnumpy()
319
+ video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
320
+ return video
321
+
322
+
323
+ VIDEO_READER_BACKENDS = {
324
+ "decord": _read_video_decord,
325
+ "torchvision": _read_video_torchvision,
326
+ }
327
+
328
+ FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
329
+
330
+
331
+ @lru_cache(maxsize=1)
332
+ def get_video_reader_backend() -> str:
333
+ if FORCE_QWENVL_VIDEO_READER is not None:
334
+ video_reader_backend = FORCE_QWENVL_VIDEO_READER
335
+ elif is_decord_available():
336
+ video_reader_backend = "decord"
337
+ else:
338
+ video_reader_backend = "torchvision"
339
+ print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
340
+ return video_reader_backend
341
+
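+ # The backend can be pinned from the environment, e.g.
+ #   FORCE_QWENVL_VIDEO_READER=torchvision python <your_script>.py
+ # otherwise decord is preferred whenever it is importable.
+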
342
+
343
+ def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
344
+ if isinstance(ele["video"], str):
345
+ video_reader_backend = get_video_reader_backend()
346
+ video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
348
+ nframes, _, height, width = video.shape
349
+
350
+ min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
351
+ total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
352
+ max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
353
+ max_pixels = ele.get("max_pixels", max_pixels)
354
+ if "resized_height" in ele and "resized_width" in ele:
355
+ resized_height, resized_width = smart_resize(
356
+ ele["resized_height"],
357
+ ele["resized_width"],
358
+ factor=image_factor,
359
+ )
360
+ else:
361
+ resized_height, resized_width = smart_resize(
362
+ height,
363
+ width,
364
+ factor=image_factor,
365
+ min_pixels=min_pixels,
366
+ max_pixels=max_pixels,
367
+ )
368
+ video = transforms.functional.resize(
369
+ video,
370
+ [resized_height, resized_width],
371
+ interpolation=InterpolationMode.BICUBIC,
372
+ antialias=True,
373
+ ).float()
374
+ return video
375
+ else:
376
+ assert isinstance(ele["video"], (list, tuple))
377
+ process_info = ele.copy()
378
+ process_info.pop("type", None)
379
+ process_info.pop("video", None)
380
+ images = [
381
+ fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
382
+ for video_element in ele["video"]
383
+ ]
384
+ nframes = ceil_by_factor(len(images), FRAME_FACTOR)
385
+ if len(images) < nframes:
386
+ images.extend([images[-1]] * (nframes - len(images)))
387
+ return images
388
+
389
+
390
+ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
391
+ vision_infos = []
392
+ if isinstance(conversations[0], dict):
393
+ conversations = [conversations]
394
+ for conversation in conversations:
395
+ for message in conversation:
396
+ if isinstance(message["content"], list):
397
+ for ele in message["content"]:
398
+ if (
399
+ "image" in ele
400
+ or "image_url" in ele
401
+ or "video" in ele
402
+ or ele["type"] in ("image", "image_url", "video")
403
+ ):
404
+ vision_infos.append(ele)
405
+ return vision_infos
406
+
407
+
408
+ def process_vision_info(
409
+ conversations: list[dict] | list[list[dict]],
410
+ ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]:
411
+ vision_infos = extract_vision_info(conversations)
412
+ ## Read images or videos
413
+ image_inputs = []
414
+ video_inputs = []
415
+ for vision_info in vision_infos:
416
+ if "image" in vision_info or "image_url" in vision_info:
417
+ image_inputs.append(fetch_image(vision_info))
418
+ elif "video" in vision_info:
419
+ video_inputs.append(fetch_video(vision_info))
420
+ else:
421
+ raise ValueError("image, image_url or video should in content.")
422
+ if len(image_inputs) == 0:
423
+ image_inputs = None
424
+ if len(video_inputs) == 0:
425
+ video_inputs = None
426
+ return image_inputs, video_inputs
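+
+ # Minimal usage sketch (hypothetical path), mirroring the message format built
+ # by the data collator:
+ #   msgs = [{"role": "user", "content": [
+ #       {"type": "image", "image": "file:///tmp/example.png"},
+ #       {"type": "text", "text": "Describe the image."},
+ #   ]}]
+ #   image_inputs, video_inputs = process_vision_info(msgs)  # -> ([PIL image], None)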
hpsv3/inference.py ADDED
@@ -0,0 +1,167 @@
1
+
2
+ import os
3
+ from collections.abc import Mapping
4
+ import torch
5
+ import huggingface_hub
6
+ from .dataset.utils import process_vision_info
7
+ from .dataset.data_collator_qwen import prompt_with_special_token, prompt_without_special_token, INSTRUCTION
8
+ from .utils.parser import ModelConfig, PEFTLoraConfig, TrainingConfig, DataConfig, parse_args_with_yaml
9
+ from .train import create_model_and_processor
10
+ from pathlib import Path
11
+
12
+ _MODEL_CONFIG_PATH = Path(__file__).parent / "config"
13
+
14
+ class HPSv3RewardInferencer:
15
+ def __init__(self, config_path=None, checkpoint_path=None, device='cuda', differentiable=False):
16
+ if config_path is None:
17
+ config_path = os.path.join(_MODEL_CONFIG_PATH, 'HPSv3_7B.yaml')
18
+
19
+ if checkpoint_path is None:
20
+ checkpoint_path = huggingface_hub.hf_hub_download("MizzenAI/HPSv3", 'HPSv3.safetensors', repo_type='model')
21
+
22
+ (data_config, training_args, model_config, peft_lora_config), config_path = (
23
+ parse_args_with_yaml(
24
+ (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig), config_path, is_train=False
25
+ )
26
+ )
27
+ training_args.output_dir = os.path.join(
28
+ training_args.output_dir, config_path.split("/")[-1].split(".")[0]
29
+ )
30
+ model, processor, peft_config = create_model_and_processor(
31
+ model_config=model_config,
32
+ peft_lora_config=peft_lora_config,
33
+ training_args=training_args,
34
+ differentiable=differentiable,
35
+ )
36
+
37
+ self.device = device
38
+ self.use_special_tokens = model_config.use_special_tokens
39
+
40
+ if checkpoint_path.endswith('.safetensors'):
41
+ import safetensors.torch
42
+ state_dict = safetensors.torch.load_file(checkpoint_path, device="cpu")
43
+ else:
44
+             state_dict = torch.load(checkpoint_path, map_location="cpu")
45
+
46
+ if "model" in state_dict:
47
+ state_dict = state_dict["model"]
48
+ model.load_state_dict(state_dict, strict=True)
49
+ model.eval()
50
+
51
+ self.model = model
52
+ self.processor = processor
53
+
54
+ self.model.to(self.device)
55
+ self.data_config = data_config
56
+
57
+ def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'):
58
+ """
59
+ Pad the sequences to the maximum length.
60
+ """
61
+ assert padding_side in ['right', 'left']
62
+ if sequences.shape[1] >= max_len:
63
+ return sequences, attention_mask
64
+
65
+ pad_len = max_len - sequences.shape[1]
66
+ padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0)
67
+
68
+ sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id)
69
+ attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0)
70
+
71
+ return sequences_padded, attention_mask_padded
72
+
73
+ def _prepare_input(self, data):
74
+ """
75
+ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
76
+ handling potential state.
77
+ """
78
+ if isinstance(data, Mapping):
79
+ return type(data)({k: self._prepare_input(v) for k, v in data.items()})
80
+ elif isinstance(data, (tuple, list)):
81
+ return type(data)(self._prepare_input(v) for v in data)
82
+ elif isinstance(data, torch.Tensor):
83
+ kwargs = {"device": self.device}
84
+ return data.to(**kwargs)
85
+ return data
86
+
87
+ def _prepare_inputs(self, inputs):
88
+ """
89
+ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
90
+ handling potential state.
91
+ """
92
+ inputs = self._prepare_input(inputs)
93
+ if len(inputs) == 0:
94
+             raise ValueError("Received an empty batch of inputs.")
95
+ return inputs
96
+
97
+ def prepare_batch(self, image_paths, prompts):
98
+ max_pixels = 256 * 28 * 28
99
+ min_pixels = 256 * 28 * 28
100
+ message_list = []
101
+ for text, image in zip(prompts, image_paths):
102
+ out_message = [
103
+ {
104
+ "role": "user",
105
+ "content": [
106
+ {
107
+ "type": "image",
108
+ "image": image,
109
+ "min_pixels": max_pixels,
110
+ "max_pixels": max_pixels,
111
+ },
112
+ {
113
+ "type": "text",
114
+ "text": (
115
+ INSTRUCTION.format(text_prompt=text)
116
+ + prompt_with_special_token
117
+ if self.use_special_tokens
118
+ else prompt_without_special_token
119
+ ),
120
+ },
121
+ ],
122
+ }
123
+ ]
124
+
125
+ message_list.append(out_message)
126
+
127
+ image_inputs, _ = process_vision_info(message_list)
128
+
129
+ batch = self.processor(
130
+ text=self.processor.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True),
131
+ images=image_inputs,
132
+ padding=True,
133
+ return_tensors="pt",
134
+ videos_kwargs={"do_rescale": True},
135
+ )
136
+ batch = self._prepare_inputs(batch)
137
+ return batch
138
+
139
+ def reward(self, image_paths, prompts):
140
+
141
+ batch = self.prepare_batch(image_paths, prompts)
142
+ rewards = self.model(
143
+ return_dict=True,
144
+ **batch
145
+ )["logits"]
146
+
147
+ return rewards
148
+
149
+
150
+ if __name__ == "__main__":
151
+ config_path = 'config/inference/HPSv3_7B.yaml'
152
+ checkpoint_path = 'checkpoints/HPSv3_7B.pth'
153
+ device = 'cuda'
154
+ dtype = torch.bfloat16
155
+ inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, device=device)
156
+
157
+ image_paths = [
158
+ "assets/example1.png",
159
+ "assets/example2.png"
160
+ ]
161
+ prompts = [
162
+ "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker",
163
+ "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker"
164
+ ]
165
+ rewards = inferencer.reward(image_paths, prompts)
166
+     print(rewards[0][0].item())  # outputs are (mu, sigma); we take mu as the final score
167
+ print(rewards[1][0].item())
hpsv3/model/differentiable_image_processor.py ADDED
@@ -0,0 +1,629 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Image processor class for Qwen2-VL.
21
+
22
+ This module provides both differentiable and non-differentiable image processing methods:
23
+
24
+ 1. For DIFFERENTIABLE processing (torch.autograd compatible):
25
+ - Pass torch.Tensor to _preprocess() method
26
+ - Use preprocess_tensor() method directly
27
+ - All operations use PyTorch functions (F.interpolate, tensor operations, etc.)
28
+
29
+ 2. For NON-DIFFERENTIABLE processing (original functionality):
30
+ - Pass PIL images or numpy arrays to preprocess() method
31
+ - Uses PIL/transformers image processing functions and numpy operations
32
+
33
+ The differentiable path supports:
34
+ - Bilinear interpolation for resizing (instead of PIL resampling)
35
+ - Tensor-based rescaling and normalization
36
+ - Differentiable patch extraction and reshaping
37
+ """
38
+
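+ # Minimal differentiable-path sketch (input assumed already scaled to [0, 1],
+ # hence do_rescale=False):
+ #
+ #   proc = Qwen2VLImageProcessor()
+ #   img = torch.rand(1, 3, 448, 448, requires_grad=True)
+ #   out = proc.preprocess_tensor(img, do_rescale=False)
+ #   out["pixel_values"].sum().backward()  # gradients flow back to img
+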
39
+ import math
40
+ from typing import Dict, List, Optional, Union
41
+
42
+ import numpy as np
43
+ import torch
44
+ import torch.nn.functional as F
45
+
46
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
47
+ from transformers.image_transforms import (
48
+ convert_to_rgb,
49
+ resize,
50
+ to_channel_dimension_format,
51
+ )
52
+ from transformers.image_utils import (
53
+ OPENAI_CLIP_MEAN,
54
+ OPENAI_CLIP_STD,
55
+ ChannelDimension,
56
+ ImageInput,
57
+ PILImageResampling,
58
+ VideoInput,
59
+ get_image_size,
60
+ infer_channel_dimension_format,
61
+ is_scaled_image,
62
+ is_valid_image,
63
+ make_list_of_images,
64
+ to_numpy_array,
65
+ valid_images,
66
+ validate_preprocess_arguments,
67
+ )
68
+ from transformers.utils import TensorType, is_vision_available, logging
69
+
70
+
71
+ logger = logging.get_logger(__name__)
72
+
73
+
74
+ if is_vision_available():
75
+ from PIL import Image
76
+
77
+
78
+ def make_batched_images(images) -> List[List[ImageInput]]:
79
+ """
80
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
81
+
82
+ Args:
83
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
84
+ The input image.
85
+
86
+ Returns:
87
+ list: A list of images.
88
+ """
89
+ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
90
+ return [img for img_list in images for img in img_list]
91
+
92
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
93
+ return images
94
+
95
+ elif is_valid_image(images):
96
+ return [images]
97
+
98
+ raise ValueError(f"Could not make batched images from {images}")
99
+
100
+
101
+ # Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
102
+ def make_batched_videos(videos) -> List[VideoInput]:
103
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
104
+ return videos
105
+
106
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
107
+ if isinstance(videos[0], Image.Image):
108
+ return [videos]
109
+ elif len(videos[0].shape) == 4:
110
+ return [list(video) for video in videos]
111
+
112
+ elif is_valid_image(videos) and len(videos.shape) == 4:
113
+ return [list(videos)]
114
+
115
+ raise ValueError(f"Could not make batched video from {videos}")
116
+
117
+
118
+ def smart_resize(
119
+ height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
120
+ ):
121
+ """Rescales the image so that the following conditions are met:
122
+
123
+ 1. Both dimensions (height and width) are divisible by 'factor'.
124
+
125
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
126
+
127
+ 3. The aspect ratio of the image is maintained as closely as possible.
128
+
129
+ """
130
+ if height < factor or width < factor:
131
+ raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
132
+ elif max(height, width) / min(height, width) > 200:
133
+ raise ValueError(
134
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
135
+ )
136
+ h_bar = round(height / factor) * factor
137
+ w_bar = round(width / factor) * factor
138
+ if h_bar * w_bar > max_pixels:
139
+ beta = math.sqrt((height * width) / max_pixels)
140
+ h_bar = math.floor(height / beta / factor) * factor
141
+ w_bar = math.floor(width / beta / factor) * factor
142
+ elif h_bar * w_bar < min_pixels:
143
+ beta = math.sqrt(min_pixels / (height * width))
144
+ h_bar = math.ceil(height * beta / factor) * factor
145
+ w_bar = math.ceil(width * beta / factor) * factor
146
+ return h_bar, w_bar
147
+
148
+
149
+ class Qwen2VLImageProcessor(BaseImageProcessor):
150
+ r"""
151
+ Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.
152
+
153
+ Args:
154
+ do_resize (`bool`, *optional*, defaults to `True`):
155
+ Whether to resize the image's (height, width) dimensions.
156
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
157
+ Resampling filter to use when resizing the image.
158
+ do_rescale (`bool`, *optional*, defaults to `True`):
159
+ Whether to rescale the image by the specified scale `rescale_factor`.
160
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
161
+ Scale factor to use if rescaling the image.
162
+ do_normalize (`bool`, *optional*, defaults to `True`):
163
+ Whether to normalize the image.
164
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
165
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
166
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
167
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
168
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
169
+ Whether to convert the image to RGB.
170
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
171
+ The min pixels of the image to resize the image.
172
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
173
+ The max pixels of the image to resize the image.
174
+ patch_size (`int`, *optional*, defaults to 14):
175
+             The spatial patch size of the vision encoder.
176
+ temporal_patch_size (`int`, *optional*, defaults to 2):
177
+ The temporal patch size of the vision encoder.
178
+ merge_size (`int`, *optional*, defaults to 2):
179
+             The merge size used when mapping vision-encoder patches to LLM tokens.
180
+ """
181
+
182
+ model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
183
+
184
+ def __init__(
185
+ self,
186
+ do_resize: bool = True,
187
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
188
+ do_rescale: bool = True,
189
+ rescale_factor: Union[int, float] = 1 / 255,
190
+ do_normalize: bool = True,
191
+ image_mean: Optional[Union[float, List[float]]] = None,
192
+ image_std: Optional[Union[float, List[float]]] = None,
193
+ do_convert_rgb: bool = True,
194
+ min_pixels: int = 56 * 56,
195
+ max_pixels: int = 28 * 28 * 1280,
196
+ patch_size: int = 14,
197
+ temporal_patch_size: int = 2,
198
+ merge_size: int = 2,
199
+ **kwargs,
200
+ ) -> None:
201
+ super().__init__(**kwargs)
202
+ self.do_resize = do_resize
203
+ self.resample = resample
204
+ self.do_rescale = do_rescale
205
+ self.rescale_factor = rescale_factor
206
+ self.do_normalize = do_normalize
207
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
208
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
209
+ self.min_pixels = min_pixels
210
+ self.max_pixels = max_pixels
211
+ self.patch_size = patch_size
212
+ self.temporal_patch_size = temporal_patch_size
213
+ self.merge_size = merge_size
214
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
215
+ self.do_convert_rgb = do_convert_rgb
216
+
217
+ def _preprocess_differentiable(
218
+ self,
219
+ images: torch.Tensor,
220
+ do_resize: bool = None,
221
+ do_rescale: bool = None,
222
+ rescale_factor: float = None,
223
+ do_normalize: bool = None,
224
+ image_mean: Optional[Union[float, List[float]]] = None,
225
+ image_std: Optional[Union[float, List[float]]] = None,
226
+ ):
227
+ """
228
+ Differentiable version of image preprocessing using torch operations.
229
+
230
+ Args:
231
+ images: torch.Tensor of shape (B, C, H, W) or (C, H, W)
232
+ Returns:
233
+ flatten_patches: torch.Tensor - flattened patches
234
+ grid_thw: tuple - (grid_t, grid_h, grid_w)
235
+ """
236
+ if images.dim() == 3:
237
+ images = images.unsqueeze(0) # Add batch dimension
238
+
239
+ batch_size, channels, height, width = images.shape
240
+
241
+ processed_images = []
242
+ resized_height, resized_width = height, width
243
+
244
+ for i in range(batch_size):
245
+ image = images[i] # (C, H, W)
246
+
247
+ if do_resize:
248
+ resized_height, resized_width = smart_resize(
249
+ height,
250
+ width,
251
+ factor=self.patch_size * self.merge_size,
252
+ min_pixels=self.min_pixels,
253
+ max_pixels=self.max_pixels,
254
+ )
255
+ # Use differentiable interpolation
256
+ image = F.interpolate(
257
+ image.unsqueeze(0),
258
+ size=(resized_height, resized_width),
259
+ mode='bilinear',
260
+ align_corners=False
261
+ ).squeeze(0)
262
+
263
+ if do_rescale:
264
+ image = image * rescale_factor
265
+
266
+ if do_normalize:
267
+ if isinstance(image_mean, (list, tuple)):
268
+ mean = torch.tensor(image_mean, device=image.device, dtype=image.dtype).view(-1, 1, 1)
269
+ std = torch.tensor(image_std, device=image.device, dtype=image.dtype).view(-1, 1, 1)
270
+ else:
271
+ mean = image_mean
272
+ std = image_std
273
+ image = (image - mean) / std
274
+
275
+ processed_images.append(image)
276
+
277
+ # Stack all processed images
278
+ patches = torch.stack(processed_images) # (B, C, H, W)
279
+
280
+ # Handle temporal dimension
281
+ if patches.shape[0] == 1:
282
+ patches = patches.repeat(self.temporal_patch_size, 1, 1, 1)
283
+
284
+ # Reshape for patch extraction
285
+ batch_size, channel, resized_height, resized_width = patches.shape
286
+ grid_t = batch_size // self.temporal_patch_size
287
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
288
+
289
+ # Differentiable patch extraction and reshaping
290
+ patches = patches.view(
291
+ grid_t,
292
+ self.temporal_patch_size,
293
+ channel,
294
+ grid_h // self.merge_size,
295
+ self.merge_size,
296
+ self.patch_size,
297
+ grid_w // self.merge_size,
298
+ self.merge_size,
299
+ self.patch_size,
300
+ )
301
+ patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
302
+ flatten_patches = patches.reshape(
303
+ grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
304
+ )
305
+
306
+ return flatten_patches, (grid_t, grid_h, grid_w)
307
+
308
+ def _preprocess(
309
+ self,
310
+ images: Union[ImageInput, VideoInput],
311
+ do_resize: bool = None,
312
+ resample: PILImageResampling = None,
313
+ do_rescale: bool = None,
314
+ rescale_factor: float = None,
315
+ do_normalize: bool = None,
316
+ image_mean: Optional[Union[float, List[float]]] = None,
317
+ image_std: Optional[Union[float, List[float]]] = None,
318
+ do_convert_rgb: bool = None,
319
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
320
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
321
+ ):
322
+ """
323
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
324
+
325
+ Args:
326
+ images (`ImageInput`):
327
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
328
+ vision_info (`List[Dict]`, *optional*):
329
+ Optional list of dictionaries containing additional information about vision inputs.
330
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
331
+ Whether to resize the image.
332
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
333
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
334
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
335
+ Whether to rescale the image.
336
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
337
+ Scale factor to use if rescaling the image.
338
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
339
+ Whether to normalize the image.
340
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
341
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
342
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
343
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
344
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
345
+ Whether to convert the image to RGB.
346
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
347
+ The channel dimension format for the output image. Can be one of:
348
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
349
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
350
+ - Unset: Use the channel dimension format of the input image.
351
+ input_data_format (`ChannelDimension` or `str`, *optional*):
352
+ The channel dimension format for the input image. Can be one of:
353
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
354
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
355
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
356
+ """
357
+ # Check if input is already a torch tensor (differentiable path)
358
+ if isinstance(images, torch.Tensor):
359
+ return self._preprocess_differentiable(
360
+ images,
361
+ do_resize=do_resize,
362
+ do_rescale=do_rescale,
363
+ rescale_factor=rescale_factor,
364
+ do_normalize=do_normalize,
365
+ image_mean=image_mean,
366
+ image_std=image_std,
367
+ )
368
+
369
+ # Original non-differentiable path for backward compatibility
370
+ images = make_list_of_images(images)
371
+
372
+ if do_convert_rgb:
373
+ images = [convert_to_rgb(image) for image in images]
374
+
375
+ # All transformations expect numpy arrays.
376
+ images = [to_numpy_array(image) for image in images]
377
+
378
+ if is_scaled_image(images[0]) and do_rescale:
379
+ logger.warning_once(
380
+ "It looks like you are trying to rescale already rescaled images. If the input"
381
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
382
+ )
383
+ if input_data_format is None:
384
+ # We assume that all images have the same channel dimension format.
385
+ input_data_format = infer_channel_dimension_format(images[0])
386
+
387
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
388
+ resized_height, resized_width = height, width
389
+ processed_images = []
390
+ for image in images:
391
+ if do_resize:
392
+ resized_height, resized_width = smart_resize(
393
+ height,
394
+ width,
395
+ factor=self.patch_size * self.merge_size,
396
+ min_pixels=self.min_pixels,
397
+ max_pixels=self.max_pixels,
398
+ )
399
+ image = resize(
400
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
401
+ )
402
+
403
+ if do_rescale:
404
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
405
+
406
+ if do_normalize:
407
+ image = self.normalize(
408
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
409
+ )
410
+
411
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
412
+ processed_images.append(image)
413
+
414
+ # NOTE: The following operations use numpy and are NOT differentiable
415
+ # For differentiable operations, pass torch.Tensor as input to use _preprocess_differentiable
416
+ patches = np.array(processed_images)
417
+ if data_format == ChannelDimension.LAST:
418
+ patches = patches.transpose(0, 3, 1, 2)
419
+ if patches.shape[0] == 1:
420
+ patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
421
+ channel = patches.shape[1]
422
+ grid_t = patches.shape[0] // self.temporal_patch_size
423
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
424
+ patches = patches.reshape(
425
+ grid_t,
426
+ self.temporal_patch_size,
427
+ channel,
428
+ grid_h // self.merge_size,
429
+ self.merge_size,
430
+ self.patch_size,
431
+ grid_w // self.merge_size,
432
+ self.merge_size,
433
+ self.patch_size,
434
+ )
435
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
436
+ flatten_patches = patches.reshape(
437
+ grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
438
+ )
439
+
440
+ return flatten_patches, (grid_t, grid_h, grid_w)
441
+
442
+ def preprocess_tensor(
443
+ self,
444
+ images: torch.Tensor,
445
+ do_resize: bool = None,
446
+ do_rescale: bool = None,
447
+ rescale_factor: float = None,
448
+ do_normalize: bool = None,
449
+ image_mean: Optional[Union[float, List[float]]] = None,
450
+ image_std: Optional[Union[float, List[float]]] = None,
451
+ ):
452
+ """
453
+ Differentiable preprocessing method for torch tensors.
454
+
455
+ Args:
456
+ images: torch.Tensor of shape (B, C, H, W) or (C, H, W)
457
+
458
+ Returns:
459
+ dict containing:
460
+ - pixel_values: torch.Tensor - processed patches
461
+ - image_grid_thw: torch.Tensor - grid dimensions
462
+ """
463
+ do_resize = do_resize if do_resize is not None else self.do_resize
464
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
465
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
466
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
467
+ image_mean = image_mean if image_mean is not None else self.image_mean
468
+ image_std = image_std if image_std is not None else self.image_std
469
+
470
+ patches, image_grid_thw = self._preprocess_differentiable(
471
+ images,
472
+ do_resize=do_resize,
473
+ do_rescale=do_rescale,
474
+ rescale_factor=rescale_factor,
475
+ do_normalize=do_normalize,
476
+ image_mean=image_mean,
477
+ image_std=image_std,
478
+ )
479
+
480
+ return {
481
+ "pixel_values": patches,
482
+ "image_grid_thw": torch.tensor(image_grid_thw, device=patches.device)
483
+ }
484
+
485
+ def preprocess(
486
+ self,
487
+ images: ImageInput,
488
+ videos: VideoInput = None,
489
+ do_resize: bool = None,
490
+ size: Dict[str, int] = None,
491
+ resample: PILImageResampling = None,
492
+ do_rescale: bool = None,
493
+ rescale_factor: float = None,
494
+ do_normalize: bool = None,
495
+ image_mean: Optional[Union[float, List[float]]] = None,
496
+ image_std: Optional[Union[float, List[float]]] = None,
497
+ do_convert_rgb: bool = None,
498
+ return_tensors: Optional[Union[str, TensorType]] = None,
499
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
500
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
501
+ ):
502
+ """
503
+ Args:
504
+ images (`ImageInput`):
505
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
506
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
507
+ videos (`VideoInput`):
508
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
509
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
510
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
511
+ Whether to resize the image.
512
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
513
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
514
+ the longest edge resized to keep the input aspect ratio.
515
+ resample (`int`, *optional*, defaults to `self.resample`):
516
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
517
+ has an effect if `do_resize` is set to `True`.
518
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
519
+ Whether to rescale the image.
520
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
521
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
522
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
523
+ Whether to normalize the image.
524
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
525
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
526
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
527
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
528
+ `True`.
529
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
530
+ Whether to convert the image to RGB.
531
+ return_tensors (`str` or `TensorType`, *optional*):
532
+ The type of tensors to return. Can be one of:
533
+ - Unset: Return a list of `np.ndarray`.
534
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
535
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
536
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
537
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
538
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
539
+ The channel dimension format for the output image. Can be one of:
540
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
541
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
542
+ - Unset: Use the channel dimension format of the input image.
543
+ input_data_format (`ChannelDimension` or `str`, *optional*):
544
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
545
+ from the input image. Can be one of:
546
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
547
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
548
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
549
+
550
+ """
551
+ do_resize = do_resize if do_resize is not None else self.do_resize
552
+ size = size if size is not None else self.size
553
+ resample = resample if resample is not None else self.resample
554
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
555
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
556
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
557
+ image_mean = image_mean if image_mean is not None else self.image_mean
558
+ image_std = image_std if image_std is not None else self.image_std
559
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
560
+
561
+ if images is not None:
562
+ images = make_batched_images(images)
563
+ if videos is not None:
564
+ videos = make_batched_videos(videos)
565
+
566
+ if images is not None and not valid_images(images):
567
+ raise ValueError(
568
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
569
+ "torch.Tensor, tf.Tensor or jax.ndarray."
570
+ )
571
+
572
+ validate_preprocess_arguments(
573
+ rescale_factor=rescale_factor,
574
+ do_normalize=do_normalize,
575
+ image_mean=image_mean,
576
+ image_std=image_std,
577
+ do_resize=do_resize,
578
+ size=size,
579
+ resample=resample,
580
+ )
581
+
582
+ if images is not None:
583
+ pixel_values, vision_grid_thws = [], []
584
+ for image in images:
585
+ patches, image_grid_thw = self._preprocess(
586
+ image,
587
+ do_resize=do_resize,
588
+ resample=resample,
589
+ do_rescale=do_rescale,
590
+ rescale_factor=rescale_factor,
591
+ do_normalize=do_normalize,
592
+ image_mean=image_mean,
593
+ image_std=image_std,
594
+ data_format=data_format,
595
+ do_convert_rgb=do_convert_rgb,
596
+ input_data_format=input_data_format,
597
+ )
598
+ pixel_values.extend(patches)
599
+ vision_grid_thws.append(image_grid_thw)
600
+ if not isinstance(pixel_values[0], torch.Tensor):
601
+ pixel_values = np.array(pixel_values)
602
+ else:
603
+ pixel_values = torch.stack(pixel_values)
604
+ vision_grid_thws = np.array(vision_grid_thws)
605
+ data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
606
+
607
+ if videos is not None:
608
+ pixel_values, vision_grid_thws = [], []
609
+ for images in videos:
610
+ patches, video_grid_thw = self._preprocess(
611
+ images,
612
+ do_resize=do_resize,
613
+ resample=resample,
614
+ do_rescale=do_rescale,
615
+ rescale_factor=rescale_factor,
616
+ do_normalize=do_normalize,
617
+ image_mean=image_mean,
618
+ image_std=image_std,
619
+ data_format=data_format,
620
+ do_convert_rgb=do_convert_rgb,
621
+ input_data_format=input_data_format,
622
+ )
623
+ pixel_values.extend(patches)
624
+ vision_grid_thws.append(video_grid_thw)
625
+ pixel_values = np.array(pixel_values)
626
+ vision_grid_thws = np.array(vision_grid_thws)
627
+ data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
628
+
629
+ return BatchFeature(data=data, tensor_type=return_tensors)
hpsv3/model/qwen2vl_trainer.py ADDED
@@ -0,0 +1,971 @@
+ import os
+ import pdb
+ import warnings
+ import time
+ import math
+ import json
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from torch.utils.tensorboard import SummaryWriter
+ import torchvision.transforms as transforms
+
+ from typing import List, Optional, Dict, Union, Any
+ import pandas as pd
+ import safetensors
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import datasets
+ from torch.utils.data import Dataset, DataLoader
+ from peft import PeftModel
+ from transformers import Qwen2VLForConditionalGeneration
+ from transformers import AutoConfig
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.trainer import TrainerCallback
+ from transformers.trainer import (
+     is_sagemaker_mp_enabled,
+     is_peft_available,
+     is_datasets_available,
+     WEIGHTS_NAME,
+     TRAINING_ARGS_NAME,
+     SAFE_WEIGHTS_NAME,
+     TRAINER_STATE_NAME,
+     PREFIX_CHECKPOINT_DIR,
+     logger,
+     speed_metrics,
+     deepspeed_init,
+     has_length,
+     EvalPrediction,
+     EvalLoopContainer,
+     PredictionOutput,
+     is_torch_xla_available,
+     denumpify_detensorize,
+     EvalLoopOutput,
+     DistributedTensorGatherer,
+     SequentialDistributedSampler,
+     nested_concat,
+ )
+ from transformers.trainer_pt_utils import IterableDatasetShard
+ from transformers.trainer_callback import TrainerControl, TrainerState
+
+ from transformers.trainer_pt_utils import nested_detach, find_batch_size
+ from transformers.training_args import TrainingArguments
+ from trl import RewardTrainer
+ from hpsv3.utils.training_utils import get_peft_state_non_lora_maybe_zero_3
+
+ class Qwen2VLRewardModelBT(Qwen2VLForConditionalGeneration):
+     def __init__(
+         self,
+         config,
+         output_dim=4,
+         reward_token="last",
+         special_token_ids=None,
+         rm_head_type="default",
+         rm_head_kwargs=None,
+     ):
+         super().__init__(config)
+         self.output_dim = output_dim
+         if rm_head_type == "default":
+             self.rm_head = nn.Linear(config.hidden_size, output_dim, bias=False)
+         elif rm_head_type == "ranknet":
+             if rm_head_kwargs is None:
+                 raise ValueError("rm_head_type='ranknet' requires rm_head_kwargs (at least 'hidden_size').")
+             for layer in range(rm_head_kwargs.get("num_layers", 3)):
+                 if layer == 0:
+                     self.rm_head = nn.Sequential(
+                         nn.Linear(config.hidden_size, rm_head_kwargs["hidden_size"]),
+                         nn.ReLU(),
+                         nn.Dropout(rm_head_kwargs.get("dropout", 0.1)),
+                     )
+                 elif layer < rm_head_kwargs.get("num_layers", 3) - 1:
+                     self.rm_head.add_module(
+                         f"layer_{layer}",
+                         nn.Sequential(
+                             nn.Linear(rm_head_kwargs["hidden_size"], rm_head_kwargs["hidden_size"]),
+                             nn.ReLU(),
+                             nn.Dropout(rm_head_kwargs.get("dropout", 0.1)),
+                         ),
+                     )
+                 else:
+                     self.rm_head.add_module(
+                         "output_layer",
+                         nn.Linear(rm_head_kwargs["hidden_size"], output_dim, bias=rm_head_kwargs.get("bias", False)),
+                     )
+         else:
+             self.rm_head = nn.Sequential(
+                 nn.Linear(config.hidden_size, 1024),
+                 nn.ReLU(),
+                 nn.Dropout(0.05),
+                 nn.Linear(1024, 16),
+                 nn.ReLU(),
+                 nn.Linear(16, output_dim),
+             )
+
+         self.rm_head.to(torch.float32)
+         self.reward_token = reward_token
+
+         self.special_token_ids = special_token_ids
+         if self.special_token_ids is not None:
+             self.reward_token = "special"
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         pixel_values_videos: Optional[torch.FloatTensor] = None,
+         image_grid_thw: Optional[torch.LongTensor] = None,
+         video_grid_thw: Optional[torch.LongTensor] = None,
+         rope_deltas: Optional[torch.LongTensor] = None,
+     ):
+         ## modified from the original Qwen2VLForConditionalGeneration.forward
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+         if inputs_embeds is None:
+             inputs_embeds = self.model.embed_tokens(input_ids)
+             if pixel_values is not None:
+                 pixel_values = pixel_values.type(self.visual.get_dtype())
+                 image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                 image_mask = (
+                     (input_ids == self.config.image_token_id)
+                     .unsqueeze(-1)
+                     .expand_as(inputs_embeds)
+                 )
+                 image_embeds = image_embeds.to(
+                     inputs_embeds.device, inputs_embeds.dtype
+                 )
+                 inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+             if pixel_values_videos is not None:
+                 pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
+                 video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                 video_mask = (
+                     (input_ids == self.config.video_token_id)
+                     .unsqueeze(-1)
+                     .expand_as(inputs_embeds)
+                 )
+                 video_embeds = video_embeds.to(
+                     inputs_embeds.device, inputs_embeds.dtype
+                 )
+                 inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(inputs_embeds.device)
+
+         outputs = self.model(
+             input_ids=None,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]  # [B, L, D]
+         with torch.autocast(device_type='cuda', dtype=torch.float32):
+             logits = self.rm_head(hidden_states)  # [B, L, N]
+
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+
+         ## get sequence lengths
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError(
+                 "Cannot handle batch sizes > 1 if no padding token is defined."
+             )
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # if no pad token is found, use modulo instead of reverse indexing for ONNX compatibility
+                 sequence_lengths = (
+                     torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 )
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+
+         ## pool the per-token logits into one reward vector per sequence
+         if self.reward_token == "last":
+             pooled_logits = logits[
+                 torch.arange(batch_size, device=logits.device), sequence_lengths
+             ]
+         elif self.reward_token == "mean":
+             ## mean of all valid (non-padding) tokens' logits
+             valid_lengths = torch.clamp(sequence_lengths, min=0, max=logits.size(1) - 1)
+             pooled_logits = torch.stack(
+                 [logits[i, : valid_lengths[i]].mean(dim=0) for i in range(batch_size)]
+             )
+         elif self.reward_token == "special":
+             # build a mask that selects the positions of the special reward tokens
+             special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
+             for special_token_id in self.special_token_ids:
+                 special_token_mask = special_token_mask | (
+                     input_ids == special_token_id
+                 )
+             pooled_logits = logits[special_token_mask, ...]
+             pooled_logits = pooled_logits.view(
+                 batch_size, 1, -1
+             )  # [B, n_special, N]; one <|Reward|> token per sequence here
+             pooled_logits = pooled_logits.view(batch_size, -1)
+         else:
+             raise ValueError("Invalid reward_token")
+
+         return {"logits": pooled_logits}
+
+
+ def _convert_A_B_to_chosen_rejected(
+     rewards_A,
+     rewards_B,
+     tied_threshold=None,
+     choice_dist=None,
+ ):
+     """
+     Inputs:
+         rewards_A: [B, 1]
+         rewards_B: [B, 1]
+     Outputs:
+         rewards_chosen: [B, 1]
+         rewards_rejected: [B, 1]
+         nontied_mask: [B, 1] (marks preference labels that are not tied)
+     """
+     chosen_label = torch.ones_like(rewards_A, dtype=torch.int64).to(
+         rewards_A.device
+     )  # [B, 1]
+     chosen_mask = chosen_label == 1
+     rejected_mask = chosen_label != 1
+
+     rewards_chosen = rewards_A
+     rewards_rejected = rewards_B
+
+     if tied_threshold is None:
+         nontied_mask = torch.ones_like(chosen_label, dtype=torch.float32).to(
+             rewards_A.device
+         )
+     else:
+         # a pair counts as non-tied when the annotator vote margin exceeds the threshold
+         nontied_mask = (
+             torch.abs(
+                 (choice_dist[:, 0] - choice_dist[:, 1]) / torch.sum(choice_dist, dim=-1)
+             )
+             > tied_threshold
+         )
+     return (
+         rewards_chosen,
+         rewards_rejected,
+         nontied_mask,
+     )
+
+
+ class PartialEmbeddingUpdateCallback(TrainerCallback):
+     """
+     Callback that trains only the embeddings of the special tokens.
+     After each step, every other row of the embedding matrix is restored
+     to its original value, so only the special tokens are actually updated.
+     """
+
+     def __init__(self, special_token_ids):
+         super().__init__()
+         self.special_token_ids = special_token_ids
+         self.orig_embeds_params = None
+
+     def on_train_begin(self, args, state, control, **kwargs):
+         model = kwargs.get("model")
+         self.orig_embeds_params = model.get_input_embeddings().weight.clone().detach()
+
+     def on_step_end(self, args, state, control, **kwargs):
+         model = kwargs.get("model")
+         tokenizer = kwargs.get("tokenizer")
+
+         index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+         index_no_updates[self.special_token_ids] = False
+         with torch.no_grad():
+             model.get_input_embeddings().weight[index_no_updates] = (
+                 self.orig_embeds_params[index_no_updates]
+             )
+
+
+ class VLMRewardTrainer(RewardTrainer):
+     def __init__(self, loss_type="regular", loss_hyperparameters=None, tied_threshold=None,
+                  visualization_steps=500, max_viz_samples=4, *args, **kwargs):
+         super(VLMRewardTrainer, self).__init__(*args, **kwargs)
+         self.loss_type = loss_type
+         self.tied_threshold = tied_threshold
+         self.rewards_chosen_accumulated = []
+         self.rewards_rejected_accumulated = []
+         # avoid a mutable default argument; fall back to an empty dict
+         self.loss_hyperparameters = loss_hyperparameters if loss_hyperparameters is not None else {}
+         self.visualization_steps = visualization_steps
+         self.max_viz_samples = max_viz_samples
+
+     def get_eval_dataloader(
+         self, eval_dataset: Optional[Union[str, Dataset]] = None
+     ) -> DataLoader:
+         """
+         Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+         Subclass and override this method if you want to inject some custom behavior.
+
+         Args:
+             eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
+                 If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`,
+                 will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`],
+                 columns not accepted by the `model.forward()` method are automatically removed.
+         """
+         if eval_dataset is None and self.eval_dataset is None:
+             raise ValueError("Trainer: evaluation requires an eval_dataset.")
+
+         # If we have persistent workers, don't do a fork bomb, especially as eval datasets
+         # don't change during training
+         dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
+         if (
+             hasattr(self, "_eval_dataloaders")
+             and dataloader_key in self._eval_dataloaders
+             and self.args.dataloader_persistent_workers
+         ):
+             return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
+
+         eval_dataset = (
+             self.eval_dataset[eval_dataset]
+             if isinstance(eval_dataset, str)
+             else eval_dataset if eval_dataset is not None else self.eval_dataset
+         )
+
+         data_collator = self.data_collator
+
+         if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+             eval_dataset = self._remove_unused_columns(
+                 eval_dataset, description="evaluation"
+             )
+         else:
+             data_collator = self._get_collator_with_removed_columns(
+                 data_collator, description="evaluation"
+             )
+
+         dataloader_params = {
+             "batch_size": self.args.eval_batch_size,
+             "collate_fn": data_collator,
+             "num_workers": self.args.dataloader_num_workers,
+             "pin_memory": self.args.dataloader_pin_memory,
+             "persistent_workers": self.args.dataloader_persistent_workers,
+         }
+
+         if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
+             dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
+             dataloader_params["drop_last"] = self.args.dataloader_drop_last
+             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
+         # accelerator.free_memory() will destroy the references, so
+         # we need to store the non-prepared version
+         eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
+         if self.args.dataloader_persistent_workers:
+             if hasattr(self, "_eval_dataloaders"):
+                 self._eval_dataloaders[dataloader_key] = eval_dataloader
+             else:
+                 self._eval_dataloaders = {dataloader_key: eval_dataloader}
+
+         return self.accelerator.prepare(eval_dataloader)
+
+     def create_optimizer(self):
+         """
+         Setup the optimizer.
+         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
+         the Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+         """
+         if is_sagemaker_mp_enabled():
+             return super().create_optimizer()
+
+         opt_model = self.model
+
+         if self.optimizer is None:
+             decay_parameters = self.get_decay_parameter_names(opt_model)
+             decay_parameters = [name for name in decay_parameters if "bias" not in name]
+             lr_mapper = {}
+             visual_parameters = []
+             merger_parameters = []
+             rm_head_parameters = []
+
+             if self.args.vision_lr is not None:
+                 lr_mapper["visual"] = self.args.vision_lr
+                 visual_parameters = [
+                     name
+                     for name, _ in opt_model.named_parameters()
+                     if "visual" in name and "merger" not in name
+                 ]
+             if self.args.merger_lr is not None:
+                 lr_mapper["merger"] = self.args.merger_lr
+                 merger_parameters = [
+                     name for name, _ in opt_model.named_parameters() if "merger" in name
+                 ]
+             if self.args.rm_head_lr is not None:
+                 lr_mapper["rm_head"] = self.args.rm_head_lr
+                 rm_head_parameters = [
+                     name for name, _ in opt_model.named_parameters() if "rm_head" in name
+                 ]
+
+             if len(lr_mapper) > 0:
+                 special_lr_parameters = merger_parameters + visual_parameters + rm_head_parameters
+
+                 # parameters on the base learning rate, split by weight decay
+                 optimizer_grouped_parameters = [
+                     {
+                         "params": [
+                             p
+                             for n, p in opt_model.named_parameters()
+                             if (
+                                 n in decay_parameters
+                                 and n not in special_lr_parameters
+                                 and p.requires_grad
+                             )
+                         ],
+                         "weight_decay": self.args.weight_decay,
+                     },
+                     {
+                         "params": [
+                             p
+                             for n, p in opt_model.named_parameters()
+                             if (
+                                 n not in decay_parameters
+                                 and n not in special_lr_parameters
+                                 and p.requires_grad
+                             )
+                         ],
+                         "weight_decay": 0.0,
+                     },
+                 ]
+
+                 if visual_parameters:
+                     optimizer_grouped_parameters.extend(
+                         [
+                             {
+                                 "params": [
+                                     p
+                                     for n, p in opt_model.named_parameters()
+                                     if (
+                                         n in decay_parameters
+                                         and n in visual_parameters
+                                         and p.requires_grad
+                                     )
+                                 ],
+                                 "weight_decay": self.args.weight_decay,
+                                 "lr": self.args.vision_lr,
+                             },
+                             {
+                                 "params": [
+                                     p
+                                     for n, p in opt_model.named_parameters()
+                                     if (
+                                         n not in decay_parameters
+                                         and n in visual_parameters
+                                         and p.requires_grad
+                                     )
+                                 ],
+                                 "weight_decay": 0.0,
+                                 "lr": self.args.vision_lr,
+                             },
+                         ]
+                     )
+
+                 if merger_parameters:
+                     optimizer_grouped_parameters.extend(
+                         [
+                             {
+                                 "params": [
+                                     p
+                                     for n, p in opt_model.named_parameters()
+                                     if (
+                                         n in decay_parameters
+                                         and n in merger_parameters
+                                         and p.requires_grad
+                                     )
+                                 ],
+                                 "weight_decay": self.args.weight_decay,
+                                 "lr": self.args.merger_lr,
+                             },
+                             {
+                                 "params": [
+                                     p
+                                     for n, p in opt_model.named_parameters()
+                                     if (
+                                         n not in decay_parameters
+                                         and n in merger_parameters
+                                         and p.requires_grad
+                                     )
+                                 ],
+                                 "weight_decay": 0.0,
+                                 "lr": self.args.merger_lr,
+                             },
+                         ]
+                     )
+
+                 if rm_head_parameters:
+                     optimizer_grouped_parameters.extend(
+                         [
+                             {
+                                 "params": [
+                                     p
+                                     for n, p in opt_model.named_parameters()
+                                     if (
+                                         n in decay_parameters
+                                         and n in rm_head_parameters
+                                         and p.requires_grad
+                                     )
+                                 ],
+                                 "weight_decay": self.args.weight_decay,
+                                 "lr": self.args.rm_head_lr,
+                             },
+                             {
+                                 "params": [
+                                     p
+                                     for n, p in opt_model.named_parameters()
+                                     if (
+                                         n not in decay_parameters
+                                         and n in rm_head_parameters
+                                         and p.requires_grad
+                                     )
+                                 ],
+                                 "weight_decay": 0.0,
+                                 "lr": self.args.rm_head_lr,
+                             },
+                         ]
+                     )
+
+             else:
+                 optimizer_grouped_parameters = [
+                     {
+                         "params": [
+                             p
+                             for n, p in opt_model.named_parameters()
+                             if (n in decay_parameters and p.requires_grad)
+                         ],
+                         "weight_decay": self.args.weight_decay,
+                     },
+                     {
+                         "params": [
+                             p
+                             for n, p in opt_model.named_parameters()
+                             if (n not in decay_parameters and p.requires_grad)
+                         ],
+                         "weight_decay": 0.0,
+                     },
+                 ]
+
+             if self.model.special_token_ids:
+                 special_token_embeddings = opt_model.get_input_embeddings().weight
+                 special_token_embeddings.requires_grad = True
+
+                 # the whole embedding matrix is trained at the special-token LR;
+                 # PartialEmbeddingUpdateCallback restores the non-special rows each step
+                 optimizer_grouped_parameters.extend(
+                     [
+                         {
+                             "params": [special_token_embeddings],
+                             "lr": self.args.special_token_lr,
+                             "weight_decay": 0.0,
+                         },
+                     ]
+                 )
+
+             optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
+                 self.args, opt_model
+             )
+
+             self.optimizer = optimizer_cls(
+                 optimizer_grouped_parameters, **optimizer_kwargs
+             )
+
+         return self.optimizer
+
+     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+         rewards_A = model(return_dict=True, **inputs["batch_1"])["logits"]
+         rewards_B = model(return_dict=True, **inputs["batch_2"])["logits"]
+
+         # Log to TensorBoard for visualization
+         if (hasattr(self.state, 'global_step') and
+             self.state.global_step % self.visualization_steps == 0 and
+             self.state.global_step > 0):
+             # Pass the original inputs, which should contain the text prompts
+             self._log_training_visualization(inputs, rewards_A, rewards_B)
+
+         # calculate the loss, optionally modulated with a margin;
+         # get chosen and rejected rewards from the chosen label
+         (
+             rewards_chosen,
+             rewards_rejected,
+             nontied_mask,
+         ) = _convert_A_B_to_chosen_rejected(
+             rewards_A,
+             rewards_B,
+             tied_threshold=self.tied_threshold,
+             choice_dist=inputs["choice_dist"],
+         )
+
+         loss_dict = {}
+
+         if self.loss_type == "bt":
+             # Bradley-Terry model
+             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected)
+             out_mask = nontied_mask
+             loss = loss * out_mask
+             loss = loss.mean()
+         elif self.loss_type == "likelihood_displacement":
+             # Bradley-Terry model with the rejected reward scaled by tau
+             loss = -nn.functional.logsigmoid(rewards_chosen - self.loss_hyperparameters['tau'] * rewards_rejected)
+             out_mask = nontied_mask
+             loss = loss * out_mask
+             loss = loss.mean()
+         elif self.loss_type == "constant_margin":
+             # Bradley-Terry model with a constant margin
+             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - 0.57)
+             out_mask = nontied_mask
+             loss = loss * out_mask
+             loss = loss.mean()
+         elif self.loss_type == "btt":
+             # Bradley-Terry-With-Ties model
+             k = 5.0
+             log_k = math.log(k)
+             log_k2_sub_1 = math.log(k**2 - 1)
+             bt_loss = -nn.functional.logsigmoid(
+                 rewards_chosen - rewards_rejected - log_k
+             )
+             same_loss = (
+                 -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - log_k)
+                 - nn.functional.logsigmoid(rewards_rejected - rewards_chosen - log_k)
+                 - log_k2_sub_1
+             )
+             loss = bt_loss * nontied_mask.float() + same_loss * (
+                 1 - nontied_mask.float()
+             )
+             out_mask = torch.ones_like(nontied_mask, dtype=torch.float32).to(
+                 rewards_A.device
+             )  # [B, 1]
+             loss = loss * out_mask
+
+             loss = loss.mean()
+         elif self.loss_type == "hpsv2":
+             device = rewards_A.device
+             rewards = torch.nn.functional.softmax(
+                 torch.cat([rewards_A, rewards_B], dim=-1), dim=-1
+             )
+             text_0_logits, text_1_logits = rewards[:, 0], rewards[:, 1]
+             label_0, label_1 = torch.ones_like(text_0_logits), torch.zeros_like(
+                 text_0_logits
+             )
+
+             text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
+             text_0_labels = torch.zeros(
+                 text_logits.shape[0], device=device, dtype=torch.long
+             )
+             text_1_labels = text_0_labels + 1
+
+             text_0_loss = torch.nn.functional.cross_entropy(
+                 text_logits, text_0_labels, reduction="none"
+             )
+             text_1_loss = torch.nn.functional.cross_entropy(
+                 text_logits, text_1_labels, reduction="none"
+             )
+
+             loss = label_0 * text_0_loss + label_1 * text_1_loss
+             loss = loss.sum()
+         elif self.loss_type == "uncertainty":
+             # each head output is (mu, log_sigma); sample pairwise score differences
+             batch_size = rewards_A.shape[0]
+             mean_chosen = rewards_A[:, 0]
+             mean_rejected = rewards_B[:, 0]
+             sigma_chosen = torch.exp(rewards_A[:, 1])
+             sigma_rejected = torch.exp(rewards_B[:, 1])
+
+             mean_z = mean_chosen - mean_rejected
+             sigma_z = torch.sqrt(sigma_chosen**2 + sigma_rejected**2)
+
+             z_samples = torch.randn(batch_size, 1000).to(sigma_z.device).to(
+                 torch.float16
+             ) * sigma_z.unsqueeze(1).repeat(1, 1000) + mean_z.unsqueeze(1).repeat(
+                 1, 1000
+             )
+             loss = -torch.nn.functional.logsigmoid(z_samples).mean()
+         else:
+             raise NotImplementedError(f"Loss type {self.loss_type} not implemented.")
+
+         loss_dict.update({"loss": loss.item()})
+
+         if return_outputs:
+             ## return rewards_A/B instead of chosen/rejected:
+             ## easier to calculate metrics for multiple attributes
+             return loss, {
+                 "rewards_A": rewards_A,
+                 "rewards_B": rewards_B,
+             }
+         return loss
+
+     def prediction_step(
+         self,
+         model,
+         inputs,
+         prediction_loss_only,
+         ignore_keys=None,
+     ):
+         model.eval()
+         inputs = self._prepare_inputs(inputs)
+         if ignore_keys is None:
+             if hasattr(self.model, "config"):
+                 ignore_keys = getattr(
+                     self.model.config, "keys_to_ignore_at_inference", []
+                 )
+             else:
+                 ignore_keys = []
+
+         with torch.no_grad():
+             loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
+
+         if prediction_loss_only:
+             return (loss, None, None)
+         loss = loss.detach()
+         logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
+         logits = nested_detach(logits)
+         if self.loss_type != "uncertainty":
+             logits = torch.cat(logits, dim=1)  # [B, 2]
+         else:
+             # keep only the mean (mu) column of each (mu, log_sigma) pair
+             logits = torch.cat([p[:, [0]] for p in logits], dim=1)
+
+         labels = torch.ones((logits.shape[0], 1)).to(logits.device)
+
+         return loss, logits, labels
+
+     def _log_training_visualization(self, inputs, rewards_A, rewards_B):
+         """Log training samples and predictions to TensorBoard."""
+         try:
+             # Get a TensorBoard writer, creating one lazily if needed
+             writer = None
+             if hasattr(self, 'log_metrics'):
+                 if hasattr(self.args, 'report_to') and 'tensorboard' in self.args.report_to:
+                     if not hasattr(self, '_tb_writer'):
+                         self._tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
+                     writer = self._tb_writer
+
+             if writer is None:
+                 return
+
+             step = self.state.global_step
+             batch_size = min(len(rewards_A), self.max_viz_samples)
+
+             # Log scalar metrics
+             for i in range(batch_size):
+                 score_A = rewards_A[i].float().detach().cpu().numpy()
+                 score_B = rewards_B[i].float().detach().cpu().numpy()
+
+                 # Convert to float for logging
+                 score_A_val = float(score_A.mean()) if score_A.ndim > 0 else float(score_A)
+                 score_B_val = float(score_B.mean()) if score_B.ndim > 0 else float(score_B)
+                 score_diff = score_A_val - score_B_val
+
+                 writer.add_scalar(f'train_viz/sample_{i}/score_A', score_A_val, step)
+                 writer.add_scalar(f'train_viz/sample_{i}/score_B', score_B_val, step)
+                 writer.add_scalar(f'train_viz/sample_{i}/score_diff', score_diff, step)
+
+                 try:
+                     # Get image data from inputs
+                     image_A = inputs['image_1'][i] if 'image_1' in inputs else None
+                     image_B = inputs['image_2'][i] if 'image_2' in inputs else None
+
+                     # Get the prompt text from the original batch
+                     prompt_A = inputs.get('text_1', ['Unknown prompt'])[i] if 'text_1' in inputs else 'Unknown prompt'
+
+                     fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
+                     fig.text(0.05, 0.05, f'Prompt:\n{prompt_A[:200]}{"..." if len(prompt_A) > 200 else ""}',
+                              ha='left', va='bottom', fontsize=8, wrap=True,
+                              bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))
+                     img_A_np = np.array(image_A)
+                     if img_A_np.ndim == 3 and img_A_np.shape[0] == 3:  # CHW format
+                         img_A_np = np.transpose(img_A_np, (1, 2, 0))
+                     img_A_np = np.clip(img_A_np, 0, 1)  # Ensure values are in [0, 1]
+                     axes[0].imshow(img_A_np)
+                     axes[0].set_title(f'Image A - Score: {score_A_val:.3f}')
+                     axes[0].axis('off')
+
+                     img_B_np = np.array(image_B)
+                     if img_B_np.ndim == 3 and img_B_np.shape[0] == 3:  # CHW format
+                         img_B_np = np.transpose(img_B_np, (1, 2, 0))
+                     img_B_np = np.clip(img_B_np, 0, 1)  # Ensure values are in [0, 1]
+                     axes[1].imshow(img_B_np)
+                     axes[1].set_title(f'Image B - Score: {score_B_val:.3f}')
+                     axes[1].axis('off')
+
+                     # Add prediction info
+                     winner = "A" if score_diff > 0 else "B"
+                     plt.suptitle(f'Step {step} - Sample {i} | Predicted Winner: Image {winner} | Diff: {score_diff:.3f}', fontsize=14)
+                     plt.tight_layout()
+
+                     # Log the figure to TensorBoard
+                     writer.add_figure(f'train_viz/sample_{i}_comparison', fig, step)
+                     plt.close(fig)
+                 except Exception as viz_error:
+                     print(f"Warning: Could not extract images for visualization: {viz_error}")
+                     continue
+
+             # Log aggregate statistics
+             all_scores_A = rewards_A.float().detach().cpu().numpy()
+             all_scores_B = rewards_B.float().detach().cpu().numpy()
+
+             writer.add_histogram('train_viz/all_scores_A', all_scores_A, step)
+             writer.add_histogram('train_viz/all_scores_B', all_scores_B, step)
+             writer.add_scalar('train_viz/mean_score_A', float(all_scores_A.mean()), step)
+             writer.add_scalar('train_viz/mean_score_B', float(all_scores_B.mean()), step)
+             writer.add_scalar('train_viz/mean_score_diff', float((all_scores_A - all_scores_B).mean()), step)
+
+         except Exception as e:
+             print(f"Error in training visualization: {e}")
+
+     def _save_checkpoint(self, model, trial, metrics=None):
+         if isinstance(self.model, PeftModel):
+             checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+             if self.hp_search_backend is None and trial is None:
+                 self.store_flos()
+
+             run_dir = self._get_output_dir(trial=trial)
+             output_dir = os.path.join(run_dir, checkpoint_folder)
+             os.makedirs(output_dir, exist_ok=True)
+
+             self.save_model(output_dir, _internal_call=True)
+
+             if not self.args.save_full_model:
+                 # also save the trainable non-LoRA weights (e.g. rm_head, embeddings)
+                 non_lora_weights = get_peft_state_non_lora_maybe_zero_3(
+                     self.model.named_parameters(), require_grad_only=True
+                 )
+                 torch.save(
+                     non_lora_weights,
+                     os.path.join(output_dir, "non_lora_state_dict.pth"),
+                 )
+
+             if not self.args.save_only_model:
+                 # Save optimizer and scheduler
+                 self._save_optimizer_and_scheduler(output_dir)
+                 # Save RNG state
+                 self._save_rng_state(output_dir)
+
+         else:
+             super(RewardTrainer, self)._save_checkpoint(model, trial, metrics)
+
+     def _save(self, output_dir: Optional[str] = None, state_dict=None):
+         # If we are executing this function, we are the process zero, so we don't check for that.
+         output_dir = output_dir if output_dir is not None else self.args.output_dir
+         os.makedirs(output_dir, exist_ok=True)
+         logger.info(f"Saving model checkpoint to {output_dir}")
+
+         supported_classes = (
+             (PreTrainedModel,)
+             if not is_peft_available()
+             else (PreTrainedModel, PeftModel)
+         )
+         # fetch the state dict up front so both branches below can rely on it
+         if state_dict is None:
+             state_dict = self.model.state_dict()
+
+         # Save a trained model and configuration using `save_pretrained()`.
+         # They can then be reloaded using `from_pretrained()`.
+         if not isinstance(self.model, supported_classes):
+             if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
+                 self.accelerator.unwrap_model(self.model).save_pretrained(
+                     output_dir,
+                     state_dict=state_dict,
+                     safe_serialization=self.args.save_safetensors,
+                 )
+             else:
+                 logger.info(
+                     "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
+                 )
+                 if self.args.save_safetensors:
+                     safetensors.torch.save_file(
+                         state_dict,
+                         os.path.join(output_dir, SAFE_WEIGHTS_NAME),
+                         metadata={"format": "pt"},
+                     )
+                 else:
+                     torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+         else:
+             if not self.args.save_full_model:
+                 state_dict = {k: v for k, v in state_dict.items() if "wte" not in k}
+                 self.model.save_pretrained(
+                     output_dir,
+                     state_dict=state_dict,
+                     safe_serialization=self.args.save_safetensors,
+                 )
+             else:
+                 torch.save(state_dict, os.path.join(output_dir, "model.pth"))
+
+         if self.tokenizer is not None:
+             os.makedirs(os.path.join(output_dir, "tokenizer"), exist_ok=True)
+             self.tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
+
+         # Good practice: save your training arguments together with the trained model
+         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+
+ def compute_multi_attr_accuracy(eval_pred, metainfo_idxs=None) -> Dict[str, float]:
+     predictions, labels = eval_pred
+     metrics = {}
+
+     pred_curr = predictions
+     label_curr = labels.squeeze(1)
+     total_count = np.sum(label_curr != 0)
+
+     rewards_chosen = pred_curr[:, 0]
+     rewards_rejected = pred_curr[:, 1]
+
+     rewards_chosen_avg = np.sum(rewards_chosen) / total_count
+     rewards_rejected_avg = np.sum(rewards_rejected) / total_count
+
+     accuracy = np.sum(rewards_chosen > rewards_rejected) / total_count
+
+     metrics.update(
+         {
+             "Acc": accuracy,
+             "R_chosen_avg": rewards_chosen_avg,
+             "R_rejected_avg": rewards_rejected_avg,
+         }
+     )
+     return metrics
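
To make the pairwise objective in compute_loss concrete, here is a minimal, self-contained sketch of the "bt" (Bradley-Terry) branch on toy reward tensors; the tie mask plays the role of nontied_mask from _convert_A_B_to_chosen_rejected, and all numbers are illustrative:

import torch
import torch.nn as nn

rewards_chosen = torch.tensor([[1.2], [0.3], [0.8]])    # [B, 1] rewards for image A
rewards_rejected = torch.tensor([[0.4], [0.5], [0.8]])  # [B, 1] rewards for image B
nontied_mask = torch.tensor([[1.0], [1.0], [0.0]])      # third pair treated as a tie

# -log sigmoid(r_chosen - r_rejected), with tied pairs masked out of the loss
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected)
loss = (loss * nontied_mask).mean()
print(loss.item())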
hpsv3/model/test_differentiable.py ADDED
@@ -0,0 +1,212 @@
+
+ import os
+ from collections.abc import Mapping
+ import torch
+ import numpy as np
+ from PIL import Image
+ import huggingface_hub
+ from hpsv3.dataset.utils import process_vision_info
+ from hpsv3.dataset.data_collator_qwen import prompt_with_special_token, prompt_without_special_token, INSTRUCTION
+ from hpsv3.utils.parser import ModelConfig, PEFTLoraConfig, TrainingConfig, DataConfig, parse_args_with_yaml
+ from hpsv3.train import create_model_and_processor
+ from pathlib import Path
+
+ _MODEL_CONFIG_PATH = Path(__file__).parent / "config"
+
+ class HPSv3RewardInferencer():
+     def __init__(self, config_path=None, checkpoint_path=None, device='cuda', differentiable=False):
+         if config_path is None:
+             config_path = os.path.join(_MODEL_CONFIG_PATH, 'HPSv3_7B.yaml')
+
+         if checkpoint_path is None:
+             # hf_hub_download takes the repo id and the filename and returns the local file path
+             checkpoint_path = huggingface_hub.hf_hub_download("xilanhua12138/HPSv3", "model.pth")
+
+         (data_config, training_args, model_config, peft_lora_config), config_path = (
+             parse_args_with_yaml(
+                 (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig), config_path, is_train=False
+             )
+         )
+         training_args.output_dir = os.path.join(
+             training_args.output_dir, config_path.split("/")[-1].split(".")[0]
+         )
+         model, processor, peft_config = create_model_and_processor(
+             model_config=model_config,
+             peft_lora_config=peft_lora_config,
+             training_args=training_args,
+             differentiable=differentiable,
+         )
+
+         self.device = device
+         self.use_special_tokens = model_config.use_special_tokens
+
+         state_dict = torch.load(checkpoint_path, map_location="cpu")
+         if "model" in state_dict:
+             state_dict = state_dict["model"]
+         model.load_state_dict(state_dict, strict=False)
+         model.eval()
+
+         self.model = model
+         self.processor = processor
+
+         self.model.to(self.device)
+         self.data_config = data_config
+
+     def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'):
+         """
+         Pad the sequences to the maximum length.
+         """
+         assert padding_side in ['right', 'left']
+         if sequences.shape[1] >= max_len:
+             return sequences, attention_mask
+
+         pad_len = max_len - sequences.shape[1]
+         padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0)
+
+         sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id)
+         attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0)
+
+         return sequences_padded, attention_mask_padded
+
+     def _prepare_input(self, data):
+         """
+         Prepare a single input before feeding it to the model, converting it to a
+         tensor on the right device if it is not one already.
+         """
+         if isinstance(data, Mapping):
+             return type(data)({k: self._prepare_input(v) for k, v in data.items()})
+         elif isinstance(data, (tuple, list)):
+             return type(data)(self._prepare_input(v) for v in data)
+         elif isinstance(data, torch.Tensor):
+             kwargs = {"device": self.device}
+             ## TODO: may also need to set dtype here (e.g. when DeepSpeed is enabled)
+             return data.to(**kwargs)
+         return data
+
+     def _prepare_inputs(self, inputs):
+         """
+         Prepare `inputs` before feeding them to the model, converting them to tensors
+         if they are not already and handling potential state.
+         """
+         inputs = self._prepare_input(inputs)
+         if len(inputs) == 0:
+             raise ValueError("Received an empty batch of inputs.")
+         return inputs
+
+     def prepare_batch(self, image_paths, prompts):
+         max_pixels = 256 * 28 * 28
+         min_pixels = 256 * 28 * 28
+         message_list = []
+         for text, image in zip(prompts, image_paths):
+             out_message = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "image",
+                             "image": image,
+                             "min_pixels": min_pixels,
+                             "max_pixels": max_pixels,
+                         },
+                         {
+                             "type": "text",
+                             "text": (
+                                 INSTRUCTION.format(text_prompt=text)
+                                 + prompt_with_special_token
+                                 if self.use_special_tokens
+                                 else prompt_without_special_token
+                             ),
+                         },
+                     ],
+                 }
+             ]
+
+             message_list.append(out_message)
+
+         image_inputs, _ = process_vision_info(message_list)
+
+         batch = self.processor(
+             text=self.processor.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True),
+             images=image_inputs,
+             padding=True,
+             return_tensors="pt",
+             videos_kwargs={"do_rescale": True},
+         )
+         batch = self._prepare_inputs(batch)
+         return batch
+
+     def reward(self, image_paths, prompts):
+         batch = self.prepare_batch(image_paths, prompts)
+         rewards = self.model(
+             return_dict=True,
+             **batch
+         )["logits"]
+
+         return rewards
+
+
+ if __name__ == "__main__":
+     config_path = '/preflab/shuiyunhao/tasks/HPSv3_official/hpsv3/config/HPSv3_7B.yaml'
+     checkpoint_path = '/preflab/shuiyunhao/tasks/HPSv3_official/checkpoints/HPSv3_7B/model.pth'
+     device = 'cuda'
+     dtype = torch.bfloat16
+     inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, differentiable=True, device=device)
+
+     images = [
+         torch.from_numpy(np.array(Image.open("assets/example1.png"))),
+         torch.from_numpy(np.array(Image.open("assets/example2.png")))
+     ]
+     prompts = [
+         "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker",
+         "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker"
+     ]
+     rewards = inferencer.reward(images, prompts)
+     print(rewards[0][0].item())  # each output is (mu, sigma); we use mu as the final score
+     print(rewards[1][0].item())
+
+     loss = rewards[0][0]
+     print(f"Loss value: {loss.item()}")
+     print(f"Loss requires_grad: {loss.requires_grad}")
+
+     if loss.requires_grad:
+         loss.backward(retain_graph=True)
+
+         has_grad = False
+         for name, param in inferencer.model.named_parameters():
+             if param.grad is not None:
+                 has_grad = True
+                 print(f"Parameter {name} has gradient norm: {param.grad.norm().item():.6f}")
+                 break
+
+         if has_grad:
+             print("has grad")
+         else:
+             print("NO GRAD!!")
+     else:
+         print("Final loss does not require gradient computation.")
+
+     # Compare the non-differentiable and differentiable preprocessing paths
+     img_pil = Image.open("assets/example1.png").convert('RGB')
+     non_diff_result = inferencer.processor.preprocess(
+         images=[img_pil],
+         return_tensors="pt"
+     )
+
+     img_tensor = torch.from_numpy(np.array(img_pil)).float().permute(2, 0, 1) / 255.0
+     diff_result = inferencer.processor.preprocess_tensor(img_tensor)
+
+     if non_diff_result['pixel_values'].shape == diff_result['pixel_values'].shape:
+         diff = torch.abs(non_diff_result['pixel_values'] - diff_result['pixel_values']).mean()
+         if diff.item() < 0.01:
+             print("Outputs match")
+         else:
+             print("Different outputs")
+     else:
+         print("Shape mismatch between non-differentiable and differentiable outputs.")
+
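
For reference, a minimal sketch of reading the two-dimensional head output used by the "uncertainty" loss, assuming rewards of shape [B, 2] holding (mu, log sigma) as in qwen2vl_trainer.py; the values here are made up:

import torch

rewards = torch.tensor([[0.9, -1.2],
                        [0.1, -0.3]])  # [B, 2]; column 0 = mu, column 1 = log(sigma)
mu = rewards[:, 0]                     # the score the inferencer reports
sigma = torch.exp(rewards[:, 1])       # the per-sample uncertainty
print(mu.tolist(), sigma.tolist())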
hpsv3/train.py ADDED
@@ -0,0 +1,315 @@
+ import json
+ import os
+ import fire
+ from dataclasses import asdict
+ from functools import partial
+ import torch
+ import torch.distributed as dist
+ from hpsv3.model.qwen2vl_trainer import (
+     Qwen2VLRewardModelBT,
+     VLMRewardTrainer,
+     compute_multi_attr_accuracy,
+     PartialEmbeddingUpdateCallback,
+ )
+ from hpsv3.dataset.pairwise_dataset import PairwiseOriginalDataset
+ from hpsv3.dataset.data_collator_qwen import QWen2VLDataCollator
+ from hpsv3.utils.parser import ModelConfig, PEFTLoraConfig, TrainingConfig, DataConfig
+ from hpsv3.utils.training_utils import load_model_from_checkpoint, find_target_linear_names
+ from hpsv3.utils.parser import parse_args_with_yaml
+ from transformers import AutoProcessor
+ from peft import LoraConfig, get_peft_model
+ from trl import get_kbit_device_map, get_quantization_config
+ from hpsv3.model.differentiable_image_processor import Qwen2VLImageProcessor
+ try:
+     import flash_attn
+ except ImportError:
+     flash_attn = None
+     print("Flash Attention is not installed. Falling back to SDPA.")
+
+ def create_model_and_processor(
+     model_config,
+     peft_lora_config,
+     training_args,
+     cache_dir=None,
+     differentiable=False,
+ ):
+     # resolve the model dtype and quantization kwargs
+     torch_dtype = (
+         model_config.torch_dtype
+         if model_config.torch_dtype in ["auto", None]
+         else getattr(torch, model_config.torch_dtype)
+     )
+     quantization_config = get_quantization_config(model_config)
+     model_kwargs = dict(
+         revision=model_config.model_revision,
+         device_map=get_kbit_device_map() if quantization_config is not None else None,
+         quantization_config=quantization_config,
+         use_cache=False,
+     )
+
+     # create the processor and set padding
+     processor = AutoProcessor.from_pretrained(
+         model_config.model_name_or_path, padding_side="right", cache_dir=cache_dir
+     )
+
+     if differentiable:
+         processor.image_processor = Qwen2VLImageProcessor()
+
+     special_token_ids = None
+     if model_config.use_special_tokens:
+         special_tokens = ["<|Reward|>"]
+         processor.tokenizer.add_special_tokens(
+             {"additional_special_tokens": special_tokens}
+         )
+         special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens)
+
+     model = Qwen2VLRewardModelBT.from_pretrained(
+         model_config.model_name_or_path,
+         output_dim=model_config.output_dim,
+         reward_token=model_config.reward_token,
+         special_token_ids=special_token_ids,
+         torch_dtype=torch_dtype,
+         attn_implementation=(
+             "flash_attention_2" if not training_args.disable_flash_attn2 and flash_attn is not None else "sdpa"
+         ),
+         cache_dir=cache_dir,
+         rm_head_type=model_config.rm_head_type,
+         rm_head_kwargs=model_config.rm_head_kwargs,
+         **model_kwargs,
+     )
+
+     if model_config.use_special_tokens:
+         model.resize_token_embeddings(len(processor.tokenizer))
+
+     if training_args.bf16:
+         model.to(torch.bfloat16)
+     if training_args.fp16:
+         model.to(torch.float16)
+
+     # the reward head always stays in float32 for numerical stability
+     model.rm_head.to(torch.float32)
+
+     # create the LoRA / PEFT model if enabled
+     if peft_lora_config.lora_enable:
+         target_modules = find_target_linear_names(
+             model,
+             num_lora_modules=peft_lora_config.num_lora_modules,
+             lora_namespan_exclude=peft_lora_config.lora_namespan_exclude,
+         )
+         peft_config = LoraConfig(
+             target_modules=target_modules,
+             r=peft_lora_config.lora_r,
+             lora_alpha=peft_lora_config.lora_alpha,
+             lora_dropout=peft_lora_config.lora_dropout,
+             task_type=peft_lora_config.lora_task_type,
+             use_rslora=peft_lora_config.use_rslora,
+             bias="none",
+             modules_to_save=peft_lora_config.lora_modules_to_save,
+         )
+         model = get_peft_model(model, peft_config)
+     else:
+         peft_config = None
+
+     model.config.tokenizer_padding_side = processor.tokenizer.padding_side
+     model.config.pad_token_id = processor.tokenizer.pad_token_id
+
+     return model, processor, peft_config
+
+
+ def save_configs_to_json(data_config, training_args, model_config, peft_lora_config):
+     """
+     Save all configurations to a JSON file.
+     """
+     config_dict = {
+         "data_config": asdict(data_config),
+         "training_args": asdict(training_args),
+         "model_config": asdict(model_config),
+         "peft_lora_config": asdict(peft_lora_config),
+     }
+     # drop information that is specific to the local device
+     del config_dict["training_args"]["local_rank"]
+     del config_dict["training_args"]["_n_gpu"]
+
+     save_path = os.path.join(training_args.output_dir, "model_config.json")
+
+     os.makedirs(training_args.output_dir, exist_ok=True)
+     print(training_args.output_dir)
+
+     with open(save_path, "w") as f:
+         json.dump(config_dict, f, indent=4)
+
+
+ def set_requires_grad(parameters, requires_grad):
+     for p in parameters:
+         p.requires_grad = requires_grad
+
+ def train(config, local_rank=0, debug=False):
+
+     ## ===> Step 1: Parse arguments
+     (data_config, training_args, model_config, peft_lora_config), config_path = (
+         parse_args_with_yaml(
+             (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig), config, is_train=True
+         )
+     )
+     training_args.output_dir = os.path.join(
+         training_args.output_dir, config.split("/")[-1].split(".")[0]
+     )
+     training_args.logging_dir = training_args.output_dir
+     # validate the LoRA config
+     assert not (
+         peft_lora_config.lora_enable and model_config.freeze_llm
+     ), "When using LoRA, the LLM should not be frozen. If you want to freeze the LLM, please disable LoRA."
+     if not peft_lora_config.lora_enable:
+         assert (
+             not peft_lora_config.vision_lora
+         ), "Error: model_config.lora_enable is not enabled, but model_config.vision_lora is enabled."
+     else:
+         if peft_lora_config.lora_namespan_exclude is None:
+             peft_lora_config.lora_namespan_exclude = []
+         if not peft_lora_config.vision_lora:
+             peft_lora_config.lora_namespan_exclude += ["visual"]
+
+     ## ===> Step 2: Load model and configure
+     model, processor, peft_config = create_model_and_processor(
+         model_config=model_config,
+         peft_lora_config=peft_lora_config,
+         training_args=training_args,
+     )
+
+     ## load pretrained weights if requested
+     if training_args.load_from_pretrained is not None:
+         model, checkpoint_step = load_model_from_checkpoint(
+             model,
+             training_args.load_from_pretrained,
+             training_args.load_from_pretrained_step,
+         )
+     model.train()
+
+     if peft_lora_config.lora_enable:
+         model_to_configure = model.model
+     else:
+         model_to_configure = model
+     # set requires_grad for the LLM
+     set_requires_grad(
+         model_to_configure.model.parameters(), not model_config.freeze_llm
+     )
+     set_requires_grad(model_to_configure.model.embed_tokens.parameters(), False)
+     if not peft_lora_config.vision_lora:
+         # set requires_grad for the visual encoder and merger
+         set_requires_grad(
+             model_to_configure.visual.parameters(), not model_config.freeze_vision_tower
+         )
+         set_requires_grad(
+             model_to_configure.visual.merger.parameters(), model_config.tune_merger
+         )
+
+         # trainable_visual_layers counts from the end of model.visual.blocks; set -1 to unfreeze all layers
+         if model_config.trainable_visual_layers:
+             assert model_config.trainable_visual_layers <= len(model_to_configure.visual.blocks), "trainable_visual_layers should be less than or equal to the number of visual blocks"
+             freeze_layer_num = len(model_to_configure.visual.blocks) - model_config.trainable_visual_layers if model_config.trainable_visual_layers > 0 else 0
+             for index, layer in enumerate(model_to_configure.visual.blocks):
+                 if index < freeze_layer_num:
+                     set_requires_grad(layer.parameters(), False)
+                 else:
+                     set_requires_grad(layer.parameters(), True)
+
+     # set requires_grad for the regression head
+     set_requires_grad(model_to_configure.rm_head.parameters(), True)
+
+     ## ===> Step 3: Load datasets and configure
+     train_dataset = PairwiseOriginalDataset(
+         data_config.train_json_list,
+         data_config.soft_label,
+         data_config.confidence_threshold,
+     )
+     test_set_dict = {}
+     for item in data_config.test_json_list:
+         test_set_dict[item[0]] = PairwiseOriginalDataset(
+             item[1],
+             data_config.soft_label,
+             data_config.confidence_threshold,
+         )
+
+     print(f"===> Selected {len(train_dataset)} samples for training.")
+     for key, value in test_set_dict.items():
+         print(f"===> Selected {len(value)} samples for {key} testing.")
+
+     num_gpu = int(os.environ.get("WORLD_SIZE", 1))
+     data_collator = QWen2VLDataCollator(
+         processor,
+         max_pixels=data_config.max_pixels,
+         min_pixels=data_config.min_pixels,
+         with_instruction=data_config.with_instruction,
+         use_special_tokens=model_config.use_special_tokens,
+     )
+     compute_metrics = partial(compute_multi_attr_accuracy)
+
+     actual_batch_size = (
+         training_args.per_device_train_batch_size
+         * training_args.gradient_accumulation_steps
+         * num_gpu
+     )
+     total_steps = (
+         training_args.num_train_epochs * len(train_dataset) // actual_batch_size
+     )
+     if training_args.save_epochs is not None:
+         training_args.save_steps = round(
+             training_args.save_epochs * len(train_dataset) / actual_batch_size
+         )
+     if training_args.eval_epochs is not None:
+         training_args.eval_steps = round(
+             training_args.eval_epochs * len(train_dataset) / actual_batch_size
+         )
+     if training_args.logging_epochs is not None:
+         training_args.logging_steps = round(
+             training_args.logging_epochs * len(train_dataset) / actual_batch_size
+         )
+
+     if training_args.local_rank == -1 or training_args.local_rank == 0:
+         print(f"===> Using {num_gpu} GPUs.")
+         print(f"===> Total Batch Size: {actual_batch_size}")
+         print(f"===> Training Epochs: {training_args.num_train_epochs}")
+         print(f"===> Total Steps: {total_steps}")
+         print(f"===> Save Steps: {training_args.save_steps}")
+         print(f"===> Eval Steps: {training_args.eval_steps}")
+         print(f"===> Logging Steps: {training_args.logging_steps}")
+
+     ## ===> Step 4: Save configs for re-checking
+     if training_args.local_rank == -1 or training_args.local_rank == 0:
+         save_configs_to_json(data_config, training_args, model_config, peft_lora_config)
+
+     print(train_dataset)
+     ## ===> Step 5: Start training!
+
+     special_token_ids = model.special_token_ids
+     callbacks = []
+     if special_token_ids is not None:
+         callbacks.append(PartialEmbeddingUpdateCallback(special_token_ids))
+
+     trainer = VLMRewardTrainer(
+         model=model,
+         compute_metrics=compute_metrics,
+         data_collator=data_collator,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=(test_set_dict if training_args.conduct_eval else None),
+         peft_config=peft_config,
+         callbacks=callbacks,
+         loss_type=model_config.loss_type,
+         loss_hyperparameters=model_config.loss_hyperparameters,
+         tokenizer=processor.tokenizer,
+         tied_threshold=data_config.tied_threshold,
+         visualization_steps=training_args.visualization_steps,
+         max_viz_samples=training_args.max_viz_samples,
+     )
+     trainer.train()
+
+     if training_args.local_rank == -1 or training_args.local_rank == 0:
+         model_state_dict = model.state_dict()
+         torch.save(
+             model_state_dict, os.path.join(training_args.output_dir, "final_model.pth")
+         )
+         model.config.save_pretrained(training_args.output_dir)
+
+
+ if __name__ == "__main__":
+     fire.Fire(train)
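
Because the entry point is exposed through fire.Fire(train), training can also be launched programmatically along these lines; the config path here is a placeholder assumption, not a file guaranteed by this commit:

from hpsv3.train import train

# hypothetical YAML path; any config matching the dataclasses in hpsv3/utils/parser.py should work
train(config="hpsv3/config/HPSv3_7B.yaml")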
hpsv3/utils/parser.py ADDED
@@ -0,0 +1,150 @@
1
+ import sys
2
+ import yaml
3
+ from pathlib import Path
4
+ from typing import Any, Optional, Union, Tuple, List, Literal
5
+ from omegaconf import OmegaConf
6
+ from transformers import HfArgumentParser
7
+ from dataclasses import dataclass, field
8
+ from transformers import TrainingArguments
9
+
10
+ @dataclass
11
+ class DataConfig:
12
+ train_json_list: List[str] = field(default_factory=lambda: ["/path/to/dataset/meta_data.json"])
13
+ val_json_list: List[str] = field(default_factory=lambda: ["/path/to/dataset/meta_data.json"])
14
+ test_json_list: List[str] = field(default_factory=lambda: ["/path/to/dataset/meta_data.json"])
15
+ soft_label: bool = False
16
+ confidence_threshold: Optional[float] = None
17
+ max_pixels: Optional[int] = 256 * 28 * 28 # Default max pixels
18
+ min_pixels: Optional[int] = 256 * 28 * 28
19
+ with_instruction: bool = True
20
+ tied_threshold: Optional[float] = None
21
+
22
+ @dataclass
23
+ class TrainingConfig(TrainingArguments):
24
+ max_grad_norm: Optional[float] = 1.0
25
+ dataset_num_proc: Optional[int] = None
26
+ center_rewards_coefficient: Optional[float] = None
27
+ disable_flash_attn2: bool = field(default=False)
28
+ disable_dropout: bool = field(default=False)
29
+
30
+ vision_lr: Optional[float] = None
31
+ merger_lr: Optional[float] = None
32
+ rm_head_lr: Optional[float] = None
33
+ special_token_lr: Optional[float] = None
34
+
35
+ conduct_eval: Optional[bool] = True
36
+ load_from_pretrained: str = None
37
+ load_from_pretrained_step: int = None
38
+ logging_epochs: Optional[float] = None
39
+ eval_epochs: Optional[float] = None
40
+ save_epochs: Optional[float] = None
41
+ remove_unused_columns: Optional[bool] = False
42
+
43
+ save_full_model: Optional[bool] = False
44
+
45
+ # Visualization parameters
46
+ visualization_steps: Optional[int] = 100
47
+ max_viz_samples: Optional[int] = 4
48
+
49
+ @dataclass
50
+ class PEFTLoraConfig:
51
+ lora_enable: bool = False
52
+ vision_lora: bool = False
53
+ lora_r: int = 16
54
+ lora_alpha: int = 32
55
+ lora_dropout: float = 0.05
56
+ lora_target_modules: Optional[List[str]] = None
57
+ lora_namespan_exclude: Optional[List[str]] = None
58
+ lora_modules_to_save: Optional[List[str]] = None
59
+ lora_task_type: str = "CAUSAL_LM"
60
+ use_rslora: bool = False
61
+ num_lora_modules: int = -1
62
+
63
+ def __post_init__(self):
64
+ if (
65
+ isinstance(self.lora_target_modules, list)
66
+ and len(self.lora_target_modules) == 1
67
+ ):
68
+ self.lora_target_modules = self.lora_target_modules[0]
69
+
70
+ if (
71
+ isinstance(self.lora_namespan_exclude, list)
72
+ and len(self.lora_namespan_exclude) == 1
73
+ ):
74
+ self.lora_namespan_exclude = self.lora_namespan_exclude[0]
75
+
76
+
77
+ @dataclass
78
+ class ModelConfig:
79
+ model_name_or_path: Optional[str] = None
80
+ model_revision: str = "main"
81
+ rm_head_type: str = "default"
82
+ rm_head_kwargs: Optional[dict] = None
83
+ output_dim: int = 1
84
+
85
+ use_special_tokens: bool = False
86
+
87
+ freeze_vision_tower: bool = field(default=False)
88
+ freeze_llm: bool = field(default=False)
89
+ tune_merger: bool = field(default=False)
90
+ trainable_visual_layers: Optional[int] = -1
91
+
92
+ torch_dtype: Optional[Literal["auto", "bfloat16", "float16", "float32"]] = None
93
+ trust_remote_code: bool = False
94
+ attn_implementation: Optional[str] = None
95
+ load_in_8bit: bool = False
96
+ load_in_4bit: bool = False
97
+ bnb_4bit_quant_type: Literal["fp4", "nf4"] = "nf4"
98
+ use_bnb_nested_quant: bool = False
99
+ reward_token: Literal["last", "mean", "special"] = "last"
100
+ loss_type: Literal["bt", "reg", "btt", "margin", "constant_margin", "scaled"] = (
101
+ "regular"
102
+ )
103
+ loss_hyperparameters: dict = field(default_factory=lambda: {})
104
+ checkpoint_path: Optional[str] = None
105
+
106
+ def __post_init__(self):
107
+ if self.load_in_8bit and self.load_in_4bit:
108
+ raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
109
+
115
+
116
+
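The quantization fields above map naturally onto a transformers BitsAndBytesConfig; a sketch assuming that mapping (not code from this repo):

    import torch
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # load_in_4bit
        bnb_4bit_quant_type="nf4",              # bnb_4bit_quant_type
        bnb_4bit_use_double_quant=False,        # use_bnb_nested_quant
        bnb_4bit_compute_dtype=torch.bfloat16,  # torch_dtype
    )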
117
+ ########## Argument parsing with YAML support ##########
118
+
119
+ def parse_args_with_yaml(
120
+ dataclass_types: Tuple[type, ...],
121
+ config_path: Optional[str] = None,
122
+ allow_extra_keys: bool = True,
123
+ is_train: bool = True,
124
+ ) -> Tuple[Any, ...]:
125
+ """
126
+ Parse arguments using HfArgumentParser with OmegaConf for YAML support.
127
+
128
+ Args:
129
+ dataclass_types: Tuple of dataclass types for HfArgumentParser
130
+ config_path: Path to the YAML config file (if None, taken from sys.argv[1])
131
+ allow_extra_keys: Whether to allow extra keys in the config
+ is_train: If False, the "deepspeed" entry is dropped from the config
132
+
133
+ Returns:
134
+ Tuple of parsed dataclass instances, plus the config path
135
+ """
136
+ # Fall back to the first command-line argument when no config path is given
+ if config_path is None:
+ import sys
+ config_path = sys.argv[1]
137
+ # Load the YAML config into a plain dict
138
+ args = OmegaConf.to_container(OmegaConf.load(config_path))
139
+ if not is_train:
140
+ args.pop('deepspeed', None)
141
+
142
+ # Parse with HfArgumentParser
143
+ parser = HfArgumentParser(dataclass_types)
144
+ return parser.parse_dict(args, allow_extra_keys=allow_extra_keys), config_path
145
+
146
+
147
+ if __name__ == "__main__":
148
+ (data_config, training_args, model_config, peft_lora_config), config_path = parse_args_with_yaml(
149
+ (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig)
150
+ )
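A minimal usage sketch for the parser above (the YAML path is hypothetical; keys in the file are matched against the dataclass fields, with extra keys tolerated by default):

    configs, cfg_path = parse_args_with_yaml(
        (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig),
        config_path="hpsv3/config/example.yaml",  # hypothetical path
    )
    data_config, training_args, model_config, peft_lora_config = configs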
hpsv3/utils/training_utils.py ADDED
@@ -0,0 +1,158 @@
1
+ import torch
2
+ import os
3
+ import glob
4
+ import safetensors.torch  # the torch submodule must be imported explicitly for load_file
+ from typing import Dict
5
+
6
+
7
+ def maybe_zero_3(param, ignore_status=False, name=None):
8
+ from deepspeed import zero
9
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
10
+
11
+ if hasattr(param, "ds_id"):
12
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
13
+ if not ignore_status:
14
+ print(
15
+ f"Parameter {name} is not available in ZeRO-3, please check the ZeRO-3 status."
16
+ )
17
+ with zero.GatheredParameters([param]):
18
+ param = param.data.detach().cpu().clone()
19
+ else:
20
+ param = param.detach().cpu().clone()
21
+ return param
22
+
23
+
24
+ # Borrowed from peft.utils.get_peft_model_state_dict
25
+ def get_peft_state_maybe_zero_3(named_params, bias):
26
+ if bias == "none":
27
+ to_return = {k: t for k, t in named_params if "lora_" in k}
28
+ elif bias == "all":
29
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
30
+ elif bias == "lora_only":
31
+ to_return = {}
32
+ maybe_lora_bias = {}
33
+ lora_bias_names = set()
34
+ for k, t in named_params:
35
+ if "lora_" in k:
36
+ to_return[k] = t
37
+ bias_name = k.split("lora_")[0] + "bias"
38
+ lora_bias_names.add(bias_name)
39
+ elif "bias" in k:
40
+ maybe_lora_bias[k] = t
41
+ for k, t in maybe_lora_bias:
42
+ if bias_name in lora_bias_names:
43
+ to_return[bias_name] = t
44
+ else:
45
+ raise NotImplementedError
46
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
47
+ return to_return
48
+
49
+
50
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
51
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
52
+ if require_grad_only:
53
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
54
+ to_return = {
55
+ k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
56
+ }
57
+ return to_return
58
+
59
+
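Sketch of typical save-time usage of the two helpers above (the `model` handle is illustrative; the file name matches the loader further down):

    lora_sd = get_peft_state_maybe_zero_3(model.named_parameters(), bias="none")
    non_lora_sd = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
    torch.save(non_lora_sd, "non_lora_state_dict.pth")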
60
+ def _insert_adapter_name_into_state_dict(
61
+ state_dict: Dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
62
+ ) -> Dict[str, torch.Tensor]:
63
+ """Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name."""
64
+ peft_model_state_dict = {}
65
+ for key, val in state_dict.items():
66
+ if parameter_prefix in key:
67
+ suffix = key.split(parameter_prefix)[1]
68
+ if "." in suffix:
69
+ suffix_to_replace = ".".join(suffix.split(".")[1:])
70
+ key = key.replace(
71
+ suffix_to_replace, f"{adapter_name}.{suffix_to_replace}"
72
+ )
73
+ else:
74
+ key = f"{key}.{adapter_name}"
75
+ peft_model_state_dict[key] = val
76
+ else:
77
+ peft_model_state_dict[key] = val
78
+ return peft_model_state_dict
79
+
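For example (illustrative key, adapter_name="default"), the remapping above turns "model.q_proj.lora_A.weight" into "model.q_proj.lora_A.default.weight".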
80
+
81
+ def save_video(tensor, path):
82
+ from torchvision.io import write_video
83
+
84
+ # Expects a float tensor in [0, 1] with shape (T, C, H, W)
+ tensor = tensor * 255.0
85
+ tensor = tensor.permute(0, 2, 3, 1)  # write_video expects (T, H, W, C)
86
+ tensor = tensor.clamp(0, 255).byte()
87
+ write_video(path, tensor, 4, video_codec="h264")  # fixed 4 fps
88
+
89
+
90
+ def load_model_from_checkpoint(model, checkpoint_dir, checkpoint_step):
91
+ checkpoint_paths = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
92
+ checkpoint_paths.sort(key=lambda x: int(x.split("-")[-1]), reverse=True)
+ if not checkpoint_paths:
+ raise FileNotFoundError(f"No checkpoint-* directories found in {checkpoint_dir}")
93
+
94
+ if checkpoint_step is None or checkpoint_step == -1:
95
+ # get the latest checkpoint
96
+ checkpoint_path = checkpoint_paths[0]
97
+ print(
98
+ f"===> Checkpoint step is not provided, using the latest checkpoint: {checkpoint_path}"
99
+ )
100
+ else:
101
+ checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
102
+ if checkpoint_path not in checkpoint_paths:
103
+ checkpoint_path = checkpoint_paths[0]
104
+ print(
105
+ f"===> Checkpoint step {checkpoint_step} not found, using the latest checkpoint: {checkpoint_path}"
106
+ )
107
+ else:
108
+ print(
109
+ f"===> Checkpoint step {checkpoint_step} found, using the specified checkpoint: {checkpoint_path}"
110
+ )
111
+
112
+ checkpoint_step = checkpoint_path.split("checkpoint-")[-1].split("/")[0]
113
+
114
+ full_ckpt = os.path.join(checkpoint_path, "model.pth")
115
+ lora_ckpt = os.path.join(checkpoint_path, "adapter_model.safetensors")
116
+ non_lora_ckpt = os.path.join(checkpoint_path, "non_lora_state_dict.pth")
117
+ if os.path.exists(full_ckpt):
118
+ model_state_dict = torch.load(full_ckpt, map_location="cpu")
119
+ model.load_state_dict(model_state_dict)
120
+ else:
121
+ lora_state_dict = safetensors.torch.load_file(lora_ckpt)
122
+ non_lora_state_dict = torch.load(non_lora_ckpt, map_location="cpu")
123
+
124
+ lora_state_dict = _insert_adapter_name_into_state_dict(
125
+ lora_state_dict, adapter_name="default", parameter_prefix="lora_"
126
+ )
127
+
128
+ model_state_dict = model.state_dict()
129
+ model_state_dict.update(non_lora_state_dict)
130
+ model_state_dict.update(lora_state_dict)
131
+ model.load_state_dict(model_state_dict)
132
+
133
+ return model, checkpoint_step
134
+
135
+
136
+ def find_target_linear_names(
137
+ model, num_lora_modules=-1, lora_namespan_exclude=None, verbose=False
138
+ ):
139
+ """
140
+ Find the target linear modules for LoRA.
141
+ """
142
+ lora_namespan_exclude = lora_namespan_exclude or []  # avoid a mutable default argument
+ linear_cls = torch.nn.Linear
143
+ embedding_cls = torch.nn.Embedding
144
+ lora_module_names = []
145
+
146
+ for name, module in model.named_modules():
147
+ if any(ex_keyword in name for ex_keyword in lora_namespan_exclude):
148
+ # print(f"Excluding module: {name}")
149
+ continue
150
+
151
+ if isinstance(module, (linear_cls, embedding_cls)):
152
+ lora_module_names.append(name)
153
+
154
+ if num_lora_modules > 0:
155
+ lora_module_names = lora_module_names[-num_lora_modules:]
156
+ if verbose:
157
+ print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
158
+ return lora_module_names
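A short usage sketch pairing the helper above with PEFT (the exclusion list is illustrative; r/alpha/dropout mirror the PEFTLoraConfig defaults):

    from peft import LoraConfig, get_peft_model

    target_modules = find_target_linear_names(model, lora_namespan_exclude=["visual"])
    lora_config = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05,
        target_modules=target_modules, task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)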
pretrained_models/download_pretrained_models.sh ADDED
@@ -0,0 +1,63 @@
1
+ #!/bin/bash
2
+
3
+ # Create pretrained_models directory
4
+ mkdir -p pretrained_models
5
+
6
+ # Model list (using array instead of associative array to avoid ordering issues)
7
+ models=(
8
+ "black-forest-labs/FLUX.1-dev:FLUX1-dev"
9
+ "stabilityai/stable-diffusion-3-medium-diffusers:SD3-medium"
10
+ "stabilityai/stable-diffusion-xl-base-1.0:SDXL-base"
11
+ "Kwai-Kolors/Kolors-diffusers:Kolors"
12
+ "THUDM/CogView4-6B:CogView4"
13
+ "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS:PixArt"
14
+ "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers:HunyuanDiT"
15
+ "FoundationVision/Infinity"
16
+ "google/flan-t5-xl"
17
+ "Qwen/Qwen2-VL-7B-Instruct: Qwen2-VL-7B-Instruct"
18
+ "Qwen/Qwen2-VL-2B-Instruct: Qwen2-VL-2B-Instruct"
19
+ )
20
+
21
+ # Create tmux session and set up the first window
22
+ model_info="${models[0]}"
23
+ model_path="${model_info%:*}"
24
+ window_name="${model_info#*:}"
25
+ local_dir="${model_path##*/}"
26
+
27
+ # Set the first window name directly when creating session
28
+ tmux new-session -d -s download_pretrained_model -n "$window_name"
29
+
30
+ # Give tmux some time to initialize
31
+ sleep 0.5
32
+
33
+ # Set commands for the first window
34
+ tmux send-keys -t download_pretrained_model:"$window_name" "conda activate hpsv3" Enter
35
+ tmux send-keys -t download_pretrained_model:"$window_name" "export HF_ENDPOINT=https://alpha.hf-mirror.com" Enter
36
+ tmux send-keys -t download_pretrained_model:"$window_name" "cd pretrained_models" Enter
37
+ tmux send-keys -t download_pretrained_model:"$window_name" "while true; do huggingface-cli download $model_path --local-dir $local_dir && break || sleep 60; done" Enter
38
+
39
+ # Create new windows for remaining models
40
+ for i in $(seq 1 $((${#models[@]} - 1))); do
41
+ model_info="${models[$i]}"
42
+ model_path="${model_info%:*}"
43
+ window_name="${model_info#*:}"
44
+ local_dir="${model_path##*/}"
45
+
46
+ # Create new window
47
+ tmux new-window -t download_pretrained_model -n "$window_name"
48
+ # Add small delay to ensure window creation is complete
49
+ sleep 0.2
50
+ tmux send-keys -t download_pretrained_model:"$window_name" "conda activate hpsv3" Enter
51
+ tmux send-keys -t download_pretrained_model:"$window_name" "export HF_ENDPOINT=https://alpha.hf-mirror.com" Enter
52
+ tmux send-keys -t download_pretrained_model:"$window_name" "cd pretrained_models" Enter
53
+ tmux send-keys -t download_pretrained_model:"$window_name" "while true; do huggingface-cli download $model_path --local-dir $local_dir && break || sleep 60; done" Enter
54
+ done
55
+ # Switch to the first window (using the first model's window name)
56
+ first_window_name="${models[0]#*:}"
57
+ tmux select-window -t download_pretrained_model:"$first_window_name"
58
+
59
+ echo "Created tmux session 'download_pretrained_model' and started downloading all models"
60
+ echo "Use 'tmux attach -t download_pretrained_model' to view download progress"
61
+ echo "Use Ctrl+B then press number keys to switch between different download windows"
62
+ echo "Use 'tmux list-windows -t download_pretrained_model' to view all windows"
63
+ echo "Use 'tmux kill-session -t download_pretrained_model' to end the session"
pyproject.toml ADDED
@@ -0,0 +1,64 @@
1
+ [build-system]
2
+ requires = ["setuptools>=45", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "hpsv3"
7
+ version = "1.0.0"
8
+ description = "HPSv3: Towards Wide-Spectrum Human Preference Score - A VLM-based preference model for image quality assessment"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ license = {text = "MIT"}
12
+ authors = [
13
+ {name = "Yunhao Shui"},
14
+ {name = "Yuhang Ma"},
15
+ ]
16
+ keywords = ["machine learning", "computer vision", "human preference", "image quality", "VLM", "multimodal"]
17
+ classifiers = [
18
+ "Development Status :: 4 - Beta",
19
+ "Intended Audience :: Developers",
20
+ "Intended Audience :: Science/Research",
21
+ "License :: OSI Approved :: MIT License",
22
+ "Operating System :: OS Independent",
23
+ "Programming Language :: Python :: 3",
24
+ "Programming Language :: Python :: 3.8",
25
+ "Programming Language :: Python :: 3.9",
26
+ "Programming Language :: Python :: 3.10",
27
+ "Programming Language :: Python :: 3.11",
28
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
29
+ "Topic :: Software Development :: Libraries :: Python Modules",
30
+ ]
31
+ dependencies = [
32
+ "torch>=2.0.0",
33
+ "torchvision>=0.15.0",
34
+ "transformers==4.45.2",
35
+ "accelerate>=0.20.0",
36
+ "datasets>=2.10.0",
37
+ "diffusers>=0.20.0",
38
+ "Pillow>=9.0.0",
39
+ "numpy>=1.20.0",
40
+ "tqdm>=4.60.0",
41
+ "pyyaml>=6.0",
42
+ "omegaconf>=2.3.0",
43
+ "opencv-python>=4.5.0",
44
+ "safetensors>=0.3.0",
45
+ "einops>=0.6.0",
46
+ "qwen-vl-utils>=0.0.8",
47
+ "timm>=0.9.0",
48
+ "deepspeed>=0.12.0",
49
+ "peft>=0.8.0",
50
+ "trl>=0.7.0",
51
+ "fire>=0.7.0"
52
+ ]
53
+
54
+ [project.urls]
55
+ Homepage = "https://mizzenai.github.io/HPSv3/"
56
+ Source = "https://github.com/MizzenAI/HPSv3"
57
+ Documentation = "https://github.com/MizzenAI/HPSv3/blob/main/README.md"
58
+ Paper = "https://arxiv.org/abs/2411.07232"
59
+
60
+ [tool.setuptools.packages.find]
61
+ include = ["hpsv3*", "generate*", "evaluate*"]
62
+
63
+ [tool.setuptools.package-data]
64
+ hpsv3 = ["config/*.yaml", "config/ds_config/*.json"]
requirements.txt CHANGED
@@ -190,6 +190,6 @@ widgetsnbextension==4.0.14
  xxhash==3.5.0
  yarl==1.20.1
  zipp==3.22.0
+ # flash-attn==2.7.4.post1
  hpsv3
- hpsv2
- # flash-attn==2.7.4.post1
+ hpsv2