cp524 committed
Commit 78f8c32 · 1 parent: 2b3be43

Add reward fns

src/smc/rewards.py ADDED
@@ -0,0 +1,241 @@
+ from PIL import Image
+ import torch
+ from importlib import resources
+
+ ASSETS_PATH = resources.files("assets")
+
+
+ def jpeg_compressibility(inference_dtype=None, device=None):
+     import io
+     import numpy as np
+
+     def loss_fn(images):
+         if images.min() < 0:  # normalize unnormalized images
+             images = ((images / 2) + 0.5).clamp(0, 1)
+         if isinstance(images, torch.Tensor):
+             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+             images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+         images = [Image.fromarray(image) for image in images]
+         buffers = [io.BytesIO() for _ in images]
+         for image, buffer in zip(images, buffers):
+             image.save(buffer, format="JPEG", quality=95)
+         sizes = [buffer.tell() / 1000 for buffer in buffers]
+         loss = torch.tensor(sizes, dtype=inference_dtype, device=device)
+         rewards = -1 * loss
+
+         return loss, rewards
+
+     return loss_fn
+
+
+ def clip_score(
+     inference_dtype=None,
+     device=None,
+     return_loss=False,
+ ):
+     from src.smc.scorers.clip_scorer import CLIPScorer
+
+     scorer = CLIPScorer(dtype=torch.float32, device=device)
+     scorer.requires_grad_(False)
+
+     if not return_loss:
+         def _fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images, prompts)
+             return scores
+
+         return _fn
+
+     else:
+         def loss_fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images, prompts)
+
+             loss = -scores
+             return loss, scores
+
+         return loss_fn
+
+
+ def aesthetic_score(
+     torch_dtype=None,
+     aesthetic_target=None,
+     grad_scale=0,
+     device=None,
+     return_loss=False,
+ ):
+     from src.smc.scorers.aesthetic_scorer import AestheticScorer
+
+     scorer = AestheticScorer(dtype=torch.float32, device=device)
+     scorer.requires_grad_(False)
+
+     if not return_loss:
+         def _fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images)
+             return scores
+
+         return _fn
+
+     else:
+         def loss_fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images)
+
+             if aesthetic_target is None:  # default maximization
+                 loss = -1 * scores
+             else:
+                 # using L1 to keep on same scale
+                 loss = abs(scores - aesthetic_target)
+             return loss * grad_scale, scores
+
+         return loss_fn
+
+
+ def hps_score(
+     inference_dtype=None,
+     device=None,
+     return_loss=False,
+ ):
+     from src.smc.scorers.hpsv2_scorer import HPSv2Scorer
+
+     scorer = HPSv2Scorer(dtype=torch.float32, device=device)
+     scorer.requires_grad_(False)
+
+     if not return_loss:
+         def _fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images, prompts)
+             return scores
+
+         return _fn
+
+     else:
+         def loss_fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images, prompts)
+
+             loss = 1.0 - scores
+             return loss, scores
+
+         return loss_fn
+
+
+ def ImageReward(
+     inference_dtype=None,
+     device=None,
+     return_loss=False,
+ ):
+     from src.smc.scorers.ImageReward_scorer import ImageRewardScorer
+
+     scorer = ImageRewardScorer(dtype=torch.float32, device=device)
+     scorer.requires_grad_(False)
+
+     if not return_loss:
+         def _fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images, prompts)
+             return scores
+
+         return _fn
+
+     else:
+         def loss_fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images, prompts)
+
+             loss = -scores
+             return loss, scores
+
+         return loss_fn
+
+
+ def ImageReward_Fk_Steering(
+     inference_dtype=None,
+     device=None,
+     return_loss=False,
+     bias=None,
+ ):
+     from src.smc.scorers.image_reward_utils import rm_load
+
+     scorer = rm_load("ImageReward-v1.0")
+
+     if not return_loss:
+         def _fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer.score_batched(prompts, images)
+             if bias:
+                 scores += bias
+             return scores
+
+         return _fn
+
+     else:
+         def loss_fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer.score_batched(prompts, images)
+
+             loss = -scores
+             return loss, scores
+
+         return loss_fn
+
+
+ def PickScore(
+     inference_dtype=None,
+     device=None,
+     return_loss=False,
+ ):
+     from src.smc.scorers.PickScore_scorer import PickScoreScorer
+
+     scorer = PickScoreScorer(dtype=torch.float32, device=device)
+     scorer.requires_grad_(False)
+
+     if not return_loss:
+         def _fn(images, prompts):
+             # from src.plot_utils import save_batch_images
+             # save_batch_images(images, "output_SMC")
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images, prompts)
+             return scores
+
+         return _fn
+
+     else:
+         def loss_fn(images, prompts):
+             if images.min() < 0:  # normalize unnormalized images
+                 images = ((images / 2) + 0.5).clamp(0, 1)
+             scores = scorer(images, prompts)
+
+             loss = -scores
+             return loss, scores
+
+         return loss_fn
+
+
+ def color_match_reward(x: torch.Tensor, target_color: torch.Tensor) -> torch.Tensor:
+     """
+     Reward images whose *mean* RGB comes close to a given target color.
+
+     Args:
+         x : [B, 3, H, W] float images (e.g. in [0,1] or [0,255])
+         target_color : [3] float tensor with your desired RGB mean
+
+     Returns:
+         reward : [B] higher when image mean-color ≈ target_color
+     """
+     B, C, H, W = x.shape
+     # compute per-image mean color vector [B, 3]
+     mean_color = x.view(B, C, -1).mean(dim=2)
+
+     # squared distance in RGB space
+     dist2 = (mean_color - target_color[None, :].to(x.device)).pow(2).sum(dim=1)
+
+     # negative distance = higher reward for closer color
+     return -dist2
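
A minimal usage sketch (not part of the committed file) of how one of these reward factories might be wired into a sampling loop; the device string, batch shape, and prompts below are placeholder assumptions for illustration only.

# Hypothetical usage sketch: a reward factory returns a callable over (images, prompts).
# With return_loss=True the callable returns (loss, scores); loss is -scores for CLIP.
import torch
from src.smc.rewards import clip_score

reward_fn = clip_score(device="cuda", return_loss=True)

images = torch.rand(4, 3, 512, 512, device="cuda")   # placeholder NCHW batch in [0, 1]
prompts = ["a photo of a corgi"] * 4

loss, scores = reward_fn(images, prompts)
print(scores.shape)  # torch.Size([4])
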
src/smc/scorers/ImageReward_scorer.py ADDED
@@ -0,0 +1,73 @@
+ import os
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPProcessor
+ from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain
+ from ImageReward import ImageReward_download
+
+
+ class MLP(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.layers = nn.Sequential(
+             nn.Linear(768, 1024),
+             nn.Dropout(0.2),
+             nn.Linear(1024, 128),
+             nn.Dropout(0.2),
+             nn.Linear(128, 64),
+             nn.Dropout(0.1),
+             nn.Linear(64, 16),
+             nn.Linear(16, 1),
+         )
+
+     @torch.no_grad()
+     def forward(self, embed):
+         return self.layers(embed)
+
+
+ class ImageRewardScorer(nn.Module):
+     def __init__(self, dtype, device):
+         super().__init__()
+         self.dtype = dtype
+         self.device = device
+
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+         download_root = "/vol/bitbucket/cp524/cache/ImageReward"
+         config_path = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json", download_root)
+         model_path = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt", download_root)
+         # config_path = os.path.join(download_root, "med_config.json")
+         # model_path = os.path.join(download_root, "ImageReward.pt")
+
+         self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=config_path).to(self.device, dtype=self.dtype)
+         self.mlp = MLP().to(self.device, dtype=self.dtype)
+
+         state_dict = torch.load(model_path, map_location=self.device)
+         self.load_state_dict(state_dict, strict=False)
+         self.eval()
+
+     @torch.no_grad()
+     def __call__(self, images, prompts):
+         images = (images * 255).round().clamp(0, 255).to(torch.uint8)
+         inputs = self.processor(images=images, return_tensors="pt")
+         inputs = {k: v.to(self.dtype).to(self.device) for k, v in inputs.items()}["pixel_values"]
+         image_embeds = self.blip.visual_encoder(inputs)
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+         text_input = self.blip.tokenizer(
+             prompts,
+             padding='max_length',
+             truncation=True,
+             max_length=35,
+             return_tensors="pt"
+         ).to(self.device)
+         text_output = self.blip.text_encoder(
+             text_input.input_ids,
+             attention_mask=text_input.attention_mask,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_atts,
+             return_dict=True,
+         )
+         txt_features = text_output.last_hidden_state[:, 0, :].to(dtype=self.dtype)
+         scores = self.mlp(txt_features).squeeze(1)
+
+         return scores
src/smc/scorers/PickScore_scorer.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import torch
+ from transformers import AutoModel, CLIPProcessor
+ import torchvision
+
+
+ class PickScoreScorer(torch.nn.Module):
+     def __init__(self, dtype, device):
+         super().__init__()
+         self.dtype = dtype
+         self.device = device
+
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+         checkpoint_path = "yuvalkirstain/PickScore_v1"
+         # checkpoint_path = f"{os.path.expanduser('~')}/.cache/PickScore_v1"
+         self.model = AutoModel.from_pretrained(checkpoint_path).eval().to(self.device, dtype=self.dtype)
+
+         self.target_size = 224
+         self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                                            std=[0.26862954, 0.26130258, 0.27577711])
+
+     def __call__(self, images, prompts):
+         text_inputs = self.processor(
+             text=prompts,
+             padding=True,
+             truncation=True,
+             max_length=77,
+             return_tensors="pt",
+         ).to(self.device)
+         text_embeds = self.model.get_text_features(**text_inputs)
+         text_embeds = text_embeds / torch.norm(text_embeds, dim=-1, keepdim=True)
+
+         inputs = torchvision.transforms.Resize(self.target_size)(images)
+         inputs = self.normalize(inputs).to(self.dtype)
+         image_embeds = self.model.get_image_features(pixel_values=inputs)
+         image_embeds = image_embeds / torch.norm(image_embeds, dim=-1, keepdim=True)
+         logits_per_image = image_embeds @ text_embeds.T
+         scores = torch.diagonal(logits_per_image)
+
+         return scores
src/smc/scorers/__init__.py ADDED
File without changes
src/smc/scorers/aesthetic_scorer.py ADDED
@@ -0,0 +1,54 @@
+ from importlib import resources
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPModel, CLIPProcessor
+ import torchvision
+
+ ASSETS_PATH = resources.files("assets")
+
+
+ class MLP(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.layers = nn.Sequential(
+             nn.Linear(768, 1024),
+             nn.Dropout(0.2),
+             nn.Linear(1024, 128),
+             nn.Dropout(0.2),
+             nn.Linear(128, 64),
+             nn.Dropout(0.1),
+             nn.Linear(64, 16),
+             nn.Linear(16, 1),
+         )
+
+     def forward(self, embed):
+         return self.layers(embed)
+
+
+ class AestheticScorer(nn.Module):
+     def __init__(self, dtype, device):
+         super().__init__()
+         self.dtype = dtype
+         self.device = device
+
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device, dtype=self.dtype)
+         self.mlp = MLP().to(self.device, dtype=self.dtype)
+
+         state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"), map_location=self.device)
+         self.mlp.load_state_dict(state_dict)
+
+         self.target_size = 224
+         self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                                            std=[0.26862954, 0.26130258, 0.27577711])
+
+         self.eval()
+
+     def __call__(self, images):
+         inputs = torchvision.transforms.Resize(self.target_size)(images)
+         inputs = self.normalize(inputs).to(self.dtype)
+         embed = self.clip.get_image_features(pixel_values=inputs)
+         embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
+
+         return self.mlp(embed).squeeze(1)
src/smc/scorers/clip_scorer.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import torch
+ from transformers import CLIPProcessor, CLIPModel
+ import torchvision
+
+
+ class CLIPScorer(torch.nn.Module):
+     def __init__(self, dtype, device):
+         super().__init__()
+         self.dtype = dtype
+         self.device = device
+
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+         self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device, dtype=self.dtype)
+
+         self.target_size = 224
+         self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                                            std=[0.26862954, 0.26130258, 0.27577711])
+
+     def __call__(self, images, prompts):
+         text_inputs = self.processor(
+             text=prompts,
+             padding=True,
+             truncation=True,
+             max_length=77,
+             return_tensors="pt",
+         ).to(self.device)
+         text_embeds = self.model.get_text_features(**text_inputs)
+         text_embeds = text_embeds / torch.norm(text_embeds, dim=-1, keepdim=True)
+
+         inputs = torchvision.transforms.Resize(self.target_size)(images)
+         inputs = self.normalize(inputs).to(self.dtype)
+
+         image_embeds = self.model.get_image_features(pixel_values=inputs)
+         image_embeds = image_embeds / torch.norm(image_embeds, dim=-1, keepdim=True)
+         logits_per_image = image_embeds @ text_embeds.T
+         scores = torch.diagonal(logits_per_image)
+
+         return scores
src/smc/scorers/hpsv2_scorer.py ADDED
@@ -0,0 +1,56 @@
+ import os
+ import torch
+ from transformers import CLIPProcessor
+ import hpsv2
+ from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
+
+
+ class HPSv2Scorer(torch.nn.Module):
+     def __init__(self, dtype, device):
+         super().__init__()
+         self.dtype = dtype
+         self.device = device
+
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+         self.model, _, _ = create_model_and_transforms(
+             'ViT-H-14',
+             'laion2B-s32B-b79K',
+             precision=self.dtype,
+             device=self.device,
+             jit=False,
+             force_quick_gelu=False,
+             force_custom_text=False,
+             force_patch_dropout=False,
+             force_image_size=None,
+             pretrained_image=False,
+             image_mean=None,
+             image_std=None,
+             light_augmentation=True,
+             aug_cfg={},
+             output_dict=True,
+             with_score_predictor=False,
+             with_region_predictor=False
+         )
+
+         checkpoint_path = f"{os.path.expanduser('~')}/.cache/huggingface/hub/models--xswu--HPSv2/snapshots/697403c78157020a1ae59d23f111aa58ced35b0a/HPS_v2_compressed.pt"
+         # force download of model via score
+         hpsv2.score([], "")
+         checkpoint = torch.load(checkpoint_path, map_location=self.device)
+         self.model.load_state_dict(checkpoint['state_dict'])
+         self.tokenizer = get_tokenizer('ViT-H-14')
+         self.model = self.model.to(self.device, dtype=self.dtype)
+         self.model.eval()
+
+     @torch.no_grad()
+     def __call__(self, images, prompts):
+         images = (images * 255).round().clamp(0, 255).to(torch.uint8)
+         inputs = self.processor(images=images, return_tensors="pt")
+         inputs = {k: v.to(self.dtype).to(self.device) for k, v in inputs.items()}["pixel_values"]
+         text = self.tokenizer(prompts).to(self.device)
+         outputs = self.model(inputs, text)
+         image_features, text_features = outputs["image_features"], outputs["text_features"]
+         logits_per_image = image_features @ text_features.T
+         scores = torch.diagonal(logits_per_image)
+
+         return scores
src/smc/scorers/image_reward_utils.py ADDED
@@ -0,0 +1,311 @@
+ from typing import Union
+ import os
+ import torch
+ import torch.nn as nn
+
+ from PIL import Image
+ import ImageReward as RM
+
+
+ '''
+ @File       : ImageReward.py
+ @Time       : 2023/01/28 19:53:00
+ @Author     : Jiazheng Xu
+ @Contact    : xjz22@mails.tsinghua.edu.cn
+ @Description: ImageReward reward model.
+ * Based on the CLIP code base and the improved-aesthetic-predictor code base
+ * https://github.com/openai/CLIP
+ * https://github.com/christophschuhmann/improved-aesthetic-predictor
+ '''
+
+ from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+
+ from torchvision.transforms.functional import pil_to_tensor
+
+ try:
+     from torchvision.transforms import InterpolationMode
+
+     BICUBIC = InterpolationMode.BICUBIC
+ except ImportError:
+     BICUBIC = Image.BICUBIC
+
+
+ def _convert_image_to_rgb(image):
+     return image.convert("RGB")
+
+
+ def _transform(n_px):
+     return Compose(
+         [
+             Resize(n_px, interpolation=BICUBIC),
+             CenterCrop(n_px),
+             # _convert_image_to_rgb,
+             # ToTensor(),
+             Normalize(
+                 (0.48145466, 0.4578275, 0.40821073),
+                 (0.26862954, 0.26130258, 0.27577711),
+             ),
+         ]
+     )
+
+
+ class MLP(nn.Module):
+     def __init__(self, input_size):
+         super().__init__()
+         self.input_size = input_size
+
+         self.layers = nn.Sequential(
+             nn.Linear(self.input_size, 1024),
+             # nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(1024, 128),
+             # nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(128, 64),
+             # nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(64, 16),
+             # nn.ReLU(),
+             nn.Linear(16, 1),
+         )
+
+         # initialize MLP parameters
+         for name, param in self.layers.named_parameters():
+             if 'weight' in name:
+                 nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
+             if 'bias' in name:
+                 nn.init.constant_(param, val=0)
+
+     def forward(self, input):
+         return self.layers(input)
+
+
+ class IRSMC(nn.Module):
+     def __init__(self, med_config, device='cpu'):
+         super().__init__()
+         self.device = device
+
+         self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
+         self.preprocess = _transform(224)
+         self.mlp = MLP(768)
+
+         self.mean = 0.16717362830052426
+         self.std = 1.0333394966054072
+
+     def score_batched_old(self, prompts, images):
+         # score each (prompt, image) pair one at a time
+         results = []
+         for i, prompt in enumerate(prompts):
+             results.append(self.score(prompt, images[i]))
+
+         return results
+
+     def score_gard(self, prompt_ids, prompt_attention_mask, image):
+         image_embeds = self.blip.visual_encoder(image)
+         # text encode cross attention with image
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+             self.device
+         )
+         text_output = self.blip.text_encoder(
+             prompt_ids,
+             attention_mask=prompt_attention_mask,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_atts,
+             return_dict=True,
+         )
+
+         txt_features = text_output.last_hidden_state[:, 0, :]  # (feature_dim)
+         rewards = self.mlp(txt_features)
+         rewards = (rewards - self.mean) / self.std
+
+         return rewards
+
+     def score(self, prompt, image):
+         if type(image).__name__ == 'list':
+             _, rewards = self.inference_rank(prompt, image)
+             return rewards
+
+         # text encode
+         text_input = self.blip.tokenizer(
+             prompt,
+             padding='max_length',
+             truncation=True,
+             max_length=35,
+             return_tensors="pt",
+         ).to(self.device)
+
+         # image encode
+         if isinstance(image, Image.Image):
+             pil_image = image
+         elif isinstance(image, str) and os.path.isfile(image):
+             pil_image = Image.open(image)
+         else:
+             raise TypeError(
+                 r'This image parameter type is not supported yet. Please pass a PIL.Image or a file path str.'
+             )
+
+         image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+         image_embeds = self.blip.visual_encoder(image)
+
+         # text encode cross attention with image
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+             self.device
+         )
+         text_output = self.blip.text_encoder(
+             text_input.input_ids,
+             attention_mask=text_input.attention_mask,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_atts,
+             return_dict=True,
+         )
+
+         txt_features = text_output.last_hidden_state[:, 0, :].float()  # (feature_dim)
+         rewards = self.mlp(txt_features)
+         rewards = (rewards - self.mean) / self.std
+
+         return rewards.detach().cpu().numpy().item()
+
+     def score_batched(self, prompts, images):
+         assert isinstance(prompts, list)
+         assert isinstance(images, list) or isinstance(images, torch.Tensor)
+
+         # text encode
+         text_input = self.blip.tokenizer(
+             prompts,
+             padding='max_length',
+             truncation=True,
+             max_length=35,
+             return_tensors="pt",
+         ).to(self.device)
+
+         # image encode
+         images = [
+             self.preprocess(image).unsqueeze(0).to(self.device) for image in images
+         ]
+         images = torch.cat(images, 0).to(torch.float32).to(self.device)
+
+         image_embeds = self.blip.visual_encoder(images)
+
+         # text encode cross attention with image
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+             self.device
+         )
+         text_output = self.blip.text_encoder(
+             text_input.input_ids,
+             attention_mask=text_input.attention_mask,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_atts,
+             return_dict=True,
+         )
+
+         txt_features = text_output.last_hidden_state[:, 0, :].float()  # (feature_dim)
+         rewards = self.mlp(txt_features)
+         rewards = (rewards - self.mean) / self.std
+
+         return rewards.view(txt_features.shape[0])
+
+     def inference_rank(self, prompt, generations_list):
+         text_input = self.blip.tokenizer(
+             prompt,
+             padding='max_length',
+             truncation=True,
+             max_length=35,
+             return_tensors="pt",
+         ).to(self.device)
+         txt_set = []
+         for generation in generations_list:
+             # image encode
+             if isinstance(generation, Image.Image):
+                 pil_image = generation
+             elif isinstance(generation, str):
+                 if os.path.isfile(generation):
+                     pil_image = Image.open(generation)
+             else:
+                 raise TypeError(
+                     r'This image parameter type is not supported yet. Please pass a PIL.Image or a file path str.'
+                 )
+             image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+             image_embeds = self.blip.visual_encoder(image)
+
+             # text encode cross attention with image
+             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+                 self.device
+             )
+             text_output = self.blip.text_encoder(
+                 text_input.input_ids,
+                 attention_mask=text_input.attention_mask,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_atts,
+                 return_dict=True,
+             )
+             txt_set.append(text_output.last_hidden_state[:, 0, :])
+
+         txt_features = torch.cat(txt_set, 0).float()  # [image_num, feature_dim]
+         rewards = self.mlp(txt_features)  # [image_num, 1]
+         rewards = (rewards - self.mean) / self.std
+         rewards = torch.squeeze(rewards)
+         _, rank = torch.sort(rewards, dim=0, descending=True)
+         _, indices = torch.sort(rank, dim=0)
+         indices = indices + 1
+
+         return (
+             indices.detach().cpu().numpy().tolist(),
+             rewards.detach().cpu().numpy().tolist(),
+         )
+
+
+ def rm_load(
+     name: str = "ImageReward-v1.0",
+     device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+     download_root: str = None,
+     med_config: str = None,
+ ):
+     """Load an ImageReward model.
+
+     Parameters
+     ----------
+     name : str
+         A model name listed by `ImageReward.available_models()`, or the path to a model checkpoint containing the state_dict
+
+     device : Union[str, torch.device]
+         The device to put the loaded model on
+
+     download_root : str
+         Path to download the model files to; by default, it uses "~/.cache/ImageReward"
+
+     Returns
+     -------
+     model : torch.nn.Module
+         The ImageReward model
+     """
+     if name in RM.utils._MODELS:
+         model_path = RM.ImageReward_download(
+             RM.utils._MODELS[name],
+             download_root or os.path.expanduser("~/.cache/ImageReward"),
+         )
+     elif os.path.isfile(name):
+         model_path = name
+     else:
+         raise RuntimeError(f"Model {name} not found")
+
+     print('load checkpoint from %s' % model_path)
+     state_dict = torch.load(model_path, map_location='cpu')
+     # state_dict = torch.load(model_path, map_location=device)
+
+     # med_config
+     if med_config is None:
+         med_config = RM.ImageReward_download(
+             "https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
+             download_root or os.path.expanduser("~/.cache/ImageReward"),
+         )
+
+     model = IRSMC(device=device, med_config=med_config).to(device)
+     msg = model.load_state_dict(state_dict, strict=False)
+     print("checkpoint loaded")
+     model.eval()
+     # import pdb; pdb.set_trace()
+     return model
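
A minimal usage sketch (not part of the committed file) of `rm_load` and `score_batched`; the device, prompts, and random image batch are placeholder assumptions. The preprocessing pipeline above skips `ToTensor`, so `score_batched` expects tensor images in [0, 1] rather than PIL images.

# Hypothetical usage sketch: load ImageReward and score a batch of image tensors.
import torch
from src.smc.scorers.image_reward_utils import rm_load

model = rm_load("ImageReward-v1.0", device="cuda")
prompts = ["a watercolor painting of a fox", "a city skyline at night"]
images = torch.rand(len(prompts), 3, 512, 512, device="cuda")  # stand-in batch in [0, 1]

rewards = model.score_batched(prompts, images)  # normalized rewards, shape [len(prompts)]
print(rewards)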