File size: 9,596 Bytes
9b57ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import os
import json
import random
import gc
import argparse
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

from generator import Generator
from hpsv3.inference import HPSv3RewardInferencer
from hpsv3.cohp.utils_cohp.pipelines import *
from hpsv3.cohp.utils_cohp.image2image_pipeline import Image2ImagePipeline

try:
    from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
except:
    print("HPSv2 model not found, skipping HPSv2 related imports.")

try:
    import ImageReward as RM
except:
    print("ImageReward module not found, skipping ImageReward related imports.")


def initialize_hpsv2_model(device, checkpoint_path):
    """Build the HPSv2 scorer (open_clip ViT-H-14) and load fine-tuned weights.

    Args:
        device: device the network is created on and moved to.
        checkpoint_path: path to a checkpoint file whose ``'state_dict'`` entry
            holds the HPSv2 weights.

    Returns:
        ``(model_dict, tokenizer)`` where ``model_dict`` maps ``'model'`` to the
        network in eval mode and ``'preprocess_val'`` to the validation image
        transform, and ``tokenizer`` is the matching open_clip tokenizer.
    """
    arch = 'ViT-H-14'
    net, _train_transform, val_transform = create_model_and_transforms(
        arch,
        'laion2B-s32B-b79K',
        device=device,
        precision='amp',
        pretrained_image=False,
        output_dict=True,
    )

    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only point this at checkpoint files from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    net.load_state_dict(ckpt['state_dict'])
    net = net.to(device).eval()

    bundle = {'model': net, 'preprocess_val': val_transform}
    return bundle, get_tokenizer(arch)


def score_hpsv2(model_dict, tokenizer, device, img_paths, prompts):
    """Score each image against its paired prompt with the HPSv2 model.

    Args:
        model_dict: dict with ``'model'`` (the HPSv2 network) and
            ``'preprocess_val'`` (its image transform), as built by
            ``initialize_hpsv2_model``.
        tokenizer: open_clip tokenizer matching the model.
        device: device the model lives on; inputs are moved there.
        img_paths: image file paths.
        prompts: prompts paired element-wise with ``img_paths``.

    Returns:
        1-D CPU tensor whose entry ``i`` is the dot product of image ``i``'s
        features with prompt ``i``'s features (the diagonal of the
        image-by-text similarity matrix).
    """
    model = model_dict['model']
    preprocess_val = model_dict['preprocess_val']

    tensors = []
    for path in img_paths:
        # PIL opens files lazily and keeps the handle open until the image is
        # closed; the context manager releases it once the transform has read
        # the pixels (previously every handle leaked until garbage collection).
        with Image.open(path) as img:
            tensors.append(preprocess_val(img))
    images = torch.stack(tensors, dim=0).to(device)
    texts = tokenizer(prompts).to(device)

    with torch.no_grad():
        outputs = model(images, texts)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        logits_per_image = image_features @ text_features.T
        # Only the matching (image i, prompt i) pairs are wanted.
        hps_scores = torch.diagonal(logits_per_image).cpu()
    return hps_scores


def calculate_pickscore_probs(model, processor, prompt, images, device):
    """Compute PickScore-style prompt/image similarities.

    Both the text and image embeddings are L2-normalised along their last
    dimension, so each entry is a cosine similarity. Despite the function's
    name, no logit scaling or softmax is applied.

    Args:
        model: CLIP-like model exposing ``get_image_features`` and
            ``get_text_features``.
        processor: processor producing model inputs for ``images=`` and
            ``text=``; its output must support ``.to(device)``.
        prompt: text prompt (or prompts) to compare against.
        images: images accepted by ``processor``.
        device: device the processed inputs are moved to.

    Returns:
        Tensor of shape ``(num_text_rows, num_images)`` of cosine similarities.
    """
    pixel_batch = processor(images=images, padding=True, return_tensors="pt").to(device)
    token_batch = processor(text=prompt, padding=True, return_tensors="pt").to(device)

    with torch.no_grad():
        img_feats = model.get_image_features(**pixel_batch)
        img_feats = img_feats / torch.norm(img_feats, dim=-1, keepdim=True)

        txt_feats = model.get_text_features(**token_batch)
        txt_feats = txt_feats / torch.norm(txt_feats, dim=-1, keepdim=True)

        return txt_feats @ img_feats.T


_SUPPORTED_REWARD_TYPES = ('hpsv2', 'hpsv3', 'imagereward', 'pickscore')


def _score_image(reward_type, inferencer, image_path, prompt, device,
                 tokenizer=None, processor_pickscore=None):
    """Score one image against ``prompt`` with the selected reward model.

    Args:
        reward_type: one of ``'hpsv2'``, ``'hpsv3'``, ``'imagereward'``,
            ``'pickscore'``.
        inferencer: the loaded reward model; for ``'hpsv2'`` this is the model
            dict returned by ``initialize_hpsv2_model``.
        image_path: path of the image to score.
        prompt: text prompt the image should match.
        device: device forwarded to the HPSv2 / PickScore helpers.
        tokenizer: HPSv2 tokenizer (used only for ``'hpsv2'``).
        processor_pickscore: PickScore processor (used only for ``'pickscore'``).

    Returns:
        The reward for this single image.

    Raises:
        ValueError: if ``reward_type`` is not supported.
    """
    if reward_type == 'hpsv2':
        return score_hpsv2(inferencer, tokenizer, device, [image_path], [prompt]).item()
    if reward_type == 'hpsv3':
        return inferencer.reward([image_path], [prompt]).cpu().detach()[0][0].item()
    if reward_type == 'imagereward':
        # NOTE(review): assumes score(prompt, [path]) yields a scalar for a
        # single image (the mean below requires numbers) — confirm.
        return inferencer.score(prompt, [image_path])
    if reward_type == 'pickscore':
        # Close the file handle once the processor has consumed the pixels.
        with Image.open(image_path) as image:
            return calculate_pickscore_probs(
                inferencer, processor_pickscore, prompt, [image], device
            )[0][0].item()
    raise ValueError(f"Unsupported reward type: {reward_type}")


def _summarize_scores(image_paths, scores):
    """Bundle a scored batch into the record stored in the results JSON."""
    best_score = max(scores)
    return {
        'image_paths': image_paths,
        'scores': scores,
        'max_image_path': image_paths[scores.index(best_score)],
        'max_score': best_score,
    }


def generate_images(
    reward_type, prompt, index, pipeline_params, pipelines_mapping, inferencer,
    output_dir='cohp_output', num_rounds=5, strength=0.8, device='cuda:1',
    tokenizer=None, processor_pickscore=None,
):
    """Run the two-stage generate-then-refine loop for one prompt.

    Stage 1 (model preference): each pipeline in ``pipeline_params`` generates
    a batch (``batch_size=2``) of 1024x1024 images for ``prompt``. Every image
    is scored, and the pipeline with the highest mean score wins.

    Stage 2 (sample preference): starting from the winning pipeline's
    top-scoring image, ``num_rounds`` rounds of image-to-image refinement
    (``batch_size=4``) are run. Whenever a round reaches a new best score, its
    top image becomes the input of the following rounds.

    The summary is written to ``<output_dir>/result_json/<index>.json``.

    Args:
        reward_type: ``'hpsv2'``, ``'hpsv3'``, ``'imagereward'`` or ``'pickscore'``.
        prompt: text prompt to generate for.
        index: identifier used in file names of saved images and JSON.
        pipeline_params: pipeline parameter objects to try in stage 1.
        pipelines_mapping: maps each pipeline parameter object to its short name.
        inferencer: loaded reward model (the model dict for ``'hpsv2'``).
        output_dir: root directory for images and results.
        num_rounds: number of refinement rounds.
        strength: img2img strength for the first three rounds; later rounds
            use 0.5.
        device: device for generation and scoring.
        tokenizer: HPSv2 tokenizer; defaults to the module-level ``tokenizer``.
        processor_pickscore: PickScore processor; defaults to the module-level
            ``processor_pickscore``.

    Returns:
        The results dict that was written to JSON.

    Raises:
        ValueError: if ``reward_type`` is not supported.
    """
    # Fail before any expensive generation rather than after the first batch.
    if reward_type not in _SUPPORTED_REWARD_TYPES:
        raise ValueError(f"Unsupported reward type: {reward_type}")

    # These helpers used to be read only from module globals set by the
    # ``__main__`` block; keep that as the fallback for existing callers.
    if tokenizer is None:
        tokenizer = globals().get('tokenizer')
    if processor_pickscore is None:
        processor_pickscore = globals().get('processor_pickscore')

    def score_all(paths):
        return [
            _score_image(reward_type, inferencer, path, prompt, device,
                         tokenizer=tokenizer, processor_pickscore=processor_pickscore)
            for path in paths
        ]

    os.makedirs(output_dir, exist_ok=True)
    result_json_dir = os.path.join(output_dir, 'result_json')
    os.makedirs(result_json_dir, exist_ok=True)

    info_dict = {
        'caption': prompt,
        'width': 1024,
        'height': 1024,
        'aspect_ratio': 1,
        'save_name': f"{index}_origin",
    }
    di_score_pipelines = {}
    intermediate_results_model_pref = {}
    intermediate_results_sample_pref = {}

    # ---- Stage 1: pick the pipeline whose samples score best on average ----
    for pipeline_param in pipeline_params:
        generator = Generator(
            device=device,
            pipe_name=pipeline_param.pipeline_name,
            pipe_type=pipeline_param.pipeline_type,
            pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
        )
        image_paths = generator.generate_imgs(
            info_dict=info_dict,
            generation_path=os.path.join(output_dir, pipeline_param.generation_path),
            batch_size=2,
            device=device,
            seed=random.randint(0, 75859066837),
            weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
            generation_kwargs=pipeline_param.generation_kwargs,
        )

        score_list = score_all(image_paths)
        pipeline_name = pipelines_mapping[pipeline_param]
        di_score_pipelines[pipeline_name] = sum(score_list) / len(score_list)
        intermediate_results_model_pref[pipeline_name] = _summarize_scores(image_paths, score_list)

        # Move the pipeline off the GPU and free memory before loading the next one.
        generator.pipelines.to("cpu")
        del generator
        torch.cuda.empty_cache()
        gc.collect()

    best_pipeline = max(di_score_pipelines, key=di_score_pipelines.get)
    chosen_image_path = intermediate_results_model_pref[best_pipeline]['max_image_path']

    # ---- Stage 2: iterative image-to-image refinement ----
    # Start from -inf rather than 0: with 0, a reward model whose scores are
    # all negative would never accept a refined image. For non-negative scores
    # the behaviour is unchanged (the first round is always accepted).
    max_final_score = float('-inf')
    i2ipipeline = Image2ImagePipeline(best_pipeline)
    for round_num in range(num_rounds):
        # From the fourth round on (0-based index 3+), refine more gently.
        round_strength = 0.5 if round_num >= 3 else strength
        images = i2ipipeline.generate_image(
            prompt=prompt,
            image_path=chosen_image_path,
            strength=round_strength,
            batch_size=4,
            save_prefix=f'{index}_{best_pipeline}_image2image_round{round_num + 1}',
            output_dir=output_dir,
        )

        round_summary = _summarize_scores(images, score_all(images))
        intermediate_results_sample_pref[round_num + 1] = round_summary

        # Only a new overall best replaces the image fed to the next round.
        if round_summary['max_score'] > max_final_score:
            max_final_score = round_summary['max_score']
            chosen_image_path = round_summary['max_image_path']

    results = {
        'prompt': prompt,
        'best_model': best_pipeline,
        'final_image_path': chosen_image_path,
        'model_preference_info': intermediate_results_model_pref,
        'sample_preference_intermediate_results': intermediate_results_sample_pref,
    }
    with open(os.path.join(result_json_dir, f'{index}.json'), 'w', encoding='utf-8') as file:
        json.dump(results, file, ensure_ascii=False, indent=4)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image Generation Script")
    parser.add_argument('--prompt', type=str, required=True, help='The prompt for image generation')
    parser.add_argument('--index', type=str, required=True, help='Index for saving results')
    parser.add_argument('--device', type=str, default='cuda:1', help='Device to run the model on')
    parser.add_argument(
        '--reward_model',
        type=str,
        default='hpsv3',
        choices=['hpsv2', 'hpsv3', 'pickscore', 'imagereward'],
        help='Reward model to use (hpsv2, hpsv3, pickscore, or imagereward)',
    )
    args = parser.parse_args()

    # Initialize the reward model. The HPSv2 tokenizer and the PickScore
    # processor are bound as module-level names on purpose: generate_images
    # looks them up as globals. For HPSv2 the inferencer is the model dict.
    output_dir = f"cohp_output_{args.reward_model}"
    if args.reward_model == 'hpsv2':
        model_dict, tokenizer = initialize_hpsv2_model(args.device, 'pretrained_models/HPS_v2.1_compressed.pt')
        inferencer = model_dict
    elif args.reward_model == 'hpsv3':
        inferencer = HPSv3RewardInferencer(device=args.device)
    elif args.reward_model == 'imagereward':
        inferencer = RM.load("ImageReward-v1.0").to(args.device)
    elif args.reward_model == 'pickscore':
        processor_pickscore = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        inferencer = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").eval().to(args.device)
    else:
        raise ValueError("Unsupported reward model.")

    # Candidate text-to-image pipelines and their short names.
    pipeline_params = [kolors_pipe, sd3_medium_pipe, playground_v2_5_pipe, flux_dev_pipe]
    pipelines_mapping = {
        flux_dev_pipe: 'flux',
        kolors_pipe: 'kolors',
        sd3_medium_pipe: 'sd3',
        playground_v2_5_pipe: 'playground_v2_5',
    }

    # Pass --device through: previously it was dropped, so generation and
    # scoring ran on generate_images' default 'cuda:1' while the reward model
    # had been placed on args.device.
    results = generate_images(
        reward_type=args.reward_model,
        prompt=args.prompt,
        index=args.index,
        pipeline_params=pipeline_params,
        pipelines_mapping=pipelines_mapping,
        inferencer=inferencer,
        output_dir=output_dir,
        num_rounds=4,
        device=args.device,
    )